Commit c9c259cc by Olli Etuaho Committed by Commit Bot

Add a shared traverse() function for most node types

The traversal logic for many node types is essentially the same. Use a single traverse() function for all simple node types instead of having different ones for each node type. Special traversal code is only needed for those node types where the traversal logic is overridden in specific traversers or which do special bookkeeping. This makes traverser behavior a bit more consistent: InVisit calls are now done for all node types, including if/else, ternary and loop nodes. Also false returned from visit function will always skip traversing the next children of that node. This reduces shader_translator binary size on Windows by 8 kilobytes. The added helper functions will also make it easier to implement alternative more efficient traversers. Unfortunately this also regresses compiler perf tests by around 2-3%. BUG=angleproject:2662 TEST=angle_unittests, angle_end2end_tests Change-Id: I3cb1256297b66e1db4b133b8fb84a24c349a9e29 Reviewed-on: https://chromium-review.googlesource.com/1133009Reviewed-by: 's avatarCorentin Wallez <cwallez@chromium.org> Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org> Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
parent 5598148b
......@@ -197,6 +197,7 @@
'compiler/translator/tree_util/ReplaceVariable.h',
'compiler/translator/tree_util/RunAtTheEndOfShader.cpp',
'compiler/translator/tree_util/RunAtTheEndOfShader.h',
'compiler/translator/tree_util/Visit.h',
'third_party/compiler/ArrayBoundsClamper.cpp',
'third_party/compiler/ArrayBoundsClamper.h',
],
......
......@@ -202,6 +202,61 @@ void TIntermExpression::setTypePreservePrecision(const TType &t)
return true; \
}
unsigned int TIntermSymbol::getChildCount()
{
return 0;
}
TIntermNode *TIntermSymbol::getChildNode(unsigned int index)
{
UNREACHABLE();
return nullptr;
}
unsigned int TIntermConstantUnion::getChildCount()
{
return 0;
}
TIntermNode *TIntermConstantUnion::getChildNode(unsigned int index)
{
UNREACHABLE();
return nullptr;
}
unsigned int TIntermLoop::getChildCount()
{
return (mInit ? 1 : 0) + (mCond ? 1 : 0) + (mExpr ? 1 : 0) + (mBody ? 1 : 0);
}
TIntermNode *TIntermLoop::getChildNode(unsigned int index)
{
TIntermNode *children[4];
unsigned int childIndex = 0;
if (mInit)
{
children[childIndex] = mInit;
++childIndex;
}
if (mCond)
{
children[childIndex] = mCond;
++childIndex;
}
if (mExpr)
{
children[childIndex] = mExpr;
++childIndex;
}
if (mBody)
{
children[childIndex] = mBody;
++childIndex;
}
ASSERT(index < childIndex);
return children[index];
}
bool TIntermLoop::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
ASSERT(original != nullptr); // This risks replacing multiple children.
......@@ -212,12 +267,36 @@ bool TIntermLoop::replaceChildNode(TIntermNode *original, TIntermNode *replaceme
return false;
}
unsigned int TIntermBranch::getChildCount()
{
return (mExpression ? 1 : 0);
}
TIntermNode *TIntermBranch::getChildNode(unsigned int index)
{
ASSERT(mExpression);
ASSERT(index == 0);
return mExpression;
}
bool TIntermBranch::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
REPLACE_IF_IS(mExpression, TIntermTyped, original, replacement);
return false;
}
unsigned int TIntermSwizzle::getChildCount()
{
return 1;
}
TIntermNode *TIntermSwizzle::getChildNode(unsigned int index)
{
ASSERT(mOperand);
ASSERT(index == 0);
return mOperand;
}
bool TIntermSwizzle::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
ASSERT(original->getAsTyped()->getType() == replacement->getAsTyped()->getType());
......@@ -225,6 +304,21 @@ bool TIntermSwizzle::replaceChildNode(TIntermNode *original, TIntermNode *replac
return false;
}
unsigned int TIntermBinary::getChildCount()
{
return 2;
}
TIntermNode *TIntermBinary::getChildNode(unsigned int index)
{
ASSERT(index < 2);
if (index == 0)
{
return mLeft;
}
return mRight;
}
bool TIntermBinary::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
REPLACE_IF_IS(mLeft, TIntermTyped, original, replacement);
......@@ -232,6 +326,18 @@ bool TIntermBinary::replaceChildNode(TIntermNode *original, TIntermNode *replace
return false;
}
unsigned int TIntermUnary::getChildCount()
{
return 1;
}
TIntermNode *TIntermUnary::getChildNode(unsigned int index)
{
ASSERT(mOperand);
ASSERT(index == 0);
return mOperand;
}
bool TIntermUnary::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
ASSERT(original->getAsTyped()->getType() == replacement->getAsTyped()->getType());
......@@ -239,12 +345,39 @@ bool TIntermUnary::replaceChildNode(TIntermNode *original, TIntermNode *replacem
return false;
}
unsigned int TIntermInvariantDeclaration::getChildCount()
{
return 1;
}
TIntermNode *TIntermInvariantDeclaration::getChildNode(unsigned int index)
{
ASSERT(mSymbol);
ASSERT(index == 0);
return mSymbol;
}
bool TIntermInvariantDeclaration::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
REPLACE_IF_IS(mSymbol, TIntermSymbol, original, replacement);
return false;
}
unsigned int TIntermFunctionDefinition::getChildCount()
{
return 2;
}
TIntermNode *TIntermFunctionDefinition::getChildNode(unsigned int index)
{
ASSERT(index < 2);
if (index == 0)
{
return mPrototype;
}
return mBody;
}
bool TIntermFunctionDefinition::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
REPLACE_IF_IS(mPrototype, TIntermFunctionPrototype, original, replacement);
......@@ -252,21 +385,62 @@ bool TIntermFunctionDefinition::replaceChildNode(TIntermNode *original, TIntermN
return false;
}
unsigned int TIntermAggregate::getChildCount()
{
return mArguments.size();
}
TIntermNode *TIntermAggregate::getChildNode(unsigned int index)
{
return mArguments[index];
}
bool TIntermAggregate::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
return replaceChildNodeInternal(original, replacement);
}
unsigned int TIntermBlock::getChildCount()
{
return mStatements.size();
}
TIntermNode *TIntermBlock::getChildNode(unsigned int index)
{
return mStatements[index];
}
bool TIntermBlock::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
return replaceChildNodeInternal(original, replacement);
}
unsigned int TIntermFunctionPrototype::getChildCount()
{
return 0;
}
TIntermNode *TIntermFunctionPrototype::getChildNode(unsigned int index)
{
UNREACHABLE();
return nullptr;
}
bool TIntermFunctionPrototype::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
return false;
}
unsigned int TIntermDeclaration::getChildCount()
{
return mDeclarators.size();
}
TIntermNode *TIntermDeclaration::getChildNode(unsigned int index)
{
return mDeclarators[index];
}
bool TIntermDeclaration::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
return replaceChildNodeInternal(original, replacement);
......@@ -700,6 +874,25 @@ void TIntermDeclaration::appendDeclarator(TIntermTyped *declarator)
mDeclarators.push_back(declarator);
}
unsigned int TIntermTernary::getChildCount()
{
return 3;
}
TIntermNode *TIntermTernary::getChildNode(unsigned int index)
{
ASSERT(index < 3);
if (index == 0)
{
return mCondition;
}
if (index == 1)
{
return mTrueExpression;
}
return mFalseExpression;
}
bool TIntermTernary::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
REPLACE_IF_IS(mCondition, TIntermTyped, original, replacement);
......@@ -708,6 +901,24 @@ bool TIntermTernary::replaceChildNode(TIntermNode *original, TIntermNode *replac
return false;
}
unsigned int TIntermIfElse::getChildCount()
{
return 1 + (mTrueBlock ? 1 : 0) + (mFalseBlock ? 1 : 0);
}
TIntermNode *TIntermIfElse::getChildNode(unsigned int index)
{
if (index == 0)
{
return mCondition;
}
if (mTrueBlock && index == 1)
{
return mTrueBlock;
}
return mFalseBlock;
}
bool TIntermIfElse::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
REPLACE_IF_IS(mCondition, TIntermTyped, original, replacement);
......@@ -716,6 +927,21 @@ bool TIntermIfElse::replaceChildNode(TIntermNode *original, TIntermNode *replace
return false;
}
unsigned int TIntermSwitch::getChildCount()
{
return 2;
}
TIntermNode *TIntermSwitch::getChildNode(unsigned int index)
{
ASSERT(index < 2);
if (index == 0)
{
return mInit;
}
return mStatementList;
}
bool TIntermSwitch::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
REPLACE_IF_IS(mInit, TIntermTyped, original, replacement);
......@@ -724,6 +950,18 @@ bool TIntermSwitch::replaceChildNode(TIntermNode *original, TIntermNode *replace
return false;
}
unsigned int TIntermCase::getChildCount()
{
return (mCondition ? 1 : 0);
}
TIntermNode *TIntermCase::getChildNode(unsigned int index)
{
ASSERT(index == 0);
ASSERT(mCondition);
return mCondition;
}
bool TIntermCase::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
REPLACE_IF_IS(mCondition, TIntermTyped, original, replacement);
......@@ -1117,6 +1355,7 @@ TIntermLoop::TIntermLoop(TLoopType type,
TIntermIfElse::TIntermIfElse(TIntermTyped *cond, TIntermBlock *trueB, TIntermBlock *falseB)
: TIntermNode(), mCondition(cond), mTrueBlock(trueB), mFalseBlock(falseB)
{
ASSERT(mCondition);
// Prune empty false blocks so that there won't be unnecessary operations done on it.
if (mFalseBlock && mFalseBlock->getSequence()->empty())
{
......@@ -1127,6 +1366,7 @@ TIntermIfElse::TIntermIfElse(TIntermTyped *cond, TIntermBlock *trueB, TIntermBlo
TIntermSwitch::TIntermSwitch(TIntermTyped *init, TIntermBlock *statementList)
: TIntermNode(), mInit(init), mStatementList(statementList)
{
ASSERT(mInit);
ASSERT(mStatementList);
}
......
......@@ -27,6 +27,7 @@
#include "compiler/translator/Operator.h"
#include "compiler/translator/SymbolUniqueId.h"
#include "compiler/translator/Types.h"
#include "compiler/translator/tree_util/Visit.h"
namespace sh
{
......@@ -80,7 +81,9 @@ class TIntermNode : angle::NonCopyable
const TSourceLoc &getLine() const { return mLine; }
void setLine(const TSourceLoc &l) { mLine = l; }
virtual void traverse(TIntermTraverser *) = 0;
virtual void traverse(TIntermTraverser *it);
virtual bool visit(Visit visit, TIntermTraverser *it) = 0;
virtual TIntermTyped *getAsTyped() { return 0; }
virtual TIntermConstantUnion *getAsConstantUnion() { return 0; }
virtual TIntermFunctionDefinition *getAsFunctionDefinition() { return nullptr; }
......@@ -100,6 +103,8 @@ class TIntermNode : angle::NonCopyable
virtual TIntermLoop *getAsLoopNode() { return 0; }
virtual TIntermBranch *getAsBranchNode() { return 0; }
virtual unsigned int getChildCount() = 0;
virtual TIntermNode *getChildNode(unsigned int index) = 0;
// Replace a child node. Return true if |original| is a child
// node and it is replaced; otherwise, return false.
virtual bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) = 0;
......@@ -187,7 +192,11 @@ class TIntermLoop : public TIntermNode
TIntermBlock *body);
TIntermLoop *getAsLoopNode() override { return this; }
void traverse(TIntermTraverser *it) override;
void traverse(TIntermTraverser *it) final;
bool visit(Visit visit, TIntermTraverser *it) final;
unsigned int getChildCount() final;
TIntermNode *getChildNode(unsigned int index) final;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
TLoopType getType() const { return mType; }
......@@ -217,8 +226,11 @@ class TIntermBranch : public TIntermNode
public:
TIntermBranch(TOperator op, TIntermTyped *e) : mFlowOp(op), mExpression(e) {}
void traverse(TIntermTraverser *it) override;
TIntermBranch *getAsBranchNode() override { return this; }
bool visit(Visit visit, TIntermTraverser *it) final;
unsigned int getChildCount() final;
TIntermNode *getChildNode(unsigned int index) final;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
TOperator getFlowOp() { return mFlowOp; }
......@@ -250,8 +262,12 @@ class TIntermSymbol : public TIntermTyped
ImmutableString getName() const;
const TVariable &variable() const { return *mVariable; }
void traverse(TIntermTraverser *it) override;
TIntermSymbol *getAsSymbolNode() override { return this; }
void traverse(TIntermTraverser *it) final;
bool visit(Visit visit, TIntermTraverser *it) final;
unsigned int getChildCount() final;
TIntermNode *getChildNode(unsigned int index) final;
bool replaceChildNode(TIntermNode *, TIntermNode *) override { return false; }
private:
......@@ -318,7 +334,11 @@ class TIntermConstantUnion : public TIntermExpression
}
TIntermConstantUnion *getAsConstantUnion() override { return this; }
void traverse(TIntermTraverser *it) override;
void traverse(TIntermTraverser *it) final;
bool visit(Visit visit, TIntermTraverser *it) final;
unsigned int getChildCount() final;
TIntermNode *getChildNode(unsigned int index) final;
bool replaceChildNode(TIntermNode *, TIntermNode *) override { return false; }
TConstantUnion *foldUnaryNonComponentWise(TOperator op);
......@@ -388,7 +408,10 @@ class TIntermSwizzle : public TIntermExpression
TIntermTyped *deepCopy() const override { return new TIntermSwizzle(*this); }
TIntermSwizzle *getAsSwizzleNode() override { return this; };
void traverse(TIntermTraverser *it) override;
bool visit(Visit visit, TIntermTraverser *it) final;
unsigned int getChildCount() final;
TIntermNode *getChildNode(unsigned int index) final;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
bool hasSideEffects() const override { return mOperand->hasSideEffects(); }
......@@ -433,7 +456,11 @@ class TIntermBinary : public TIntermOperator
static TOperator GetMulAssignOpBasedOnOperands(const TType &left, const TType &right);
TIntermBinary *getAsBinaryNode() override { return this; };
void traverse(TIntermTraverser *it) override;
void traverse(TIntermTraverser *it) final;
bool visit(Visit visit, TIntermTraverser *it) final;
unsigned int getChildCount() final;
TIntermNode *getChildNode(unsigned int index) final;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
bool hasSideEffects() const override
......@@ -478,8 +505,12 @@ class TIntermUnary : public TIntermOperator
TIntermTyped *deepCopy() const override { return new TIntermUnary(*this); }
void traverse(TIntermTraverser *it) override;
TIntermUnary *getAsUnaryNode() override { return this; }
void traverse(TIntermTraverser *it) final;
bool visit(Visit visit, TIntermTraverser *it) final;
unsigned int getChildCount() final;
TIntermNode *getChildNode(unsigned int index) final;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
bool hasSideEffects() const override { return isAssignment() || mOperand->hasSideEffects(); }
......@@ -555,7 +586,11 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase
const TConstantUnion *getConstantValue() const override;
TIntermAggregate *getAsAggregate() override { return this; }
void traverse(TIntermTraverser *it) override;
void traverse(TIntermTraverser *it) final;
bool visit(Visit visit, TIntermTraverser *it) final;
unsigned int getChildCount() final;
TIntermNode *getChildNode(unsigned int index) final;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
bool hasSideEffects() const override;
......@@ -620,7 +655,11 @@ class TIntermBlock : public TIntermNode, public TIntermAggregateBase
~TIntermBlock() {}
TIntermBlock *getAsBlock() override { return this; }
void traverse(TIntermTraverser *it) override;
void traverse(TIntermTraverser *it) final;
bool visit(Visit visit, TIntermTraverser *it) final;
unsigned int getChildCount() final;
TIntermNode *getChildNode(unsigned int index) final;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
// Only intended for initially building the block.
......@@ -642,7 +681,11 @@ class TIntermFunctionPrototype : public TIntermTyped
~TIntermFunctionPrototype() {}
TIntermFunctionPrototype *getAsFunctionPrototypeNode() override { return this; }
void traverse(TIntermTraverser *it) override;
void traverse(TIntermTraverser *it) final;
bool visit(Visit visit, TIntermTraverser *it) final;
unsigned int getChildCount() final;
TIntermNode *getChildNode(unsigned int index) final;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
const TType &getType() const override;
......@@ -677,7 +720,11 @@ class TIntermFunctionDefinition : public TIntermNode
}
TIntermFunctionDefinition *getAsFunctionDefinition() override { return this; }
void traverse(TIntermTraverser *it) override;
void traverse(TIntermTraverser *it) final;
bool visit(Visit visit, TIntermTraverser *it) final;
unsigned int getChildCount() final;
TIntermNode *getChildNode(unsigned int index) final;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
TIntermFunctionPrototype *getFunctionPrototype() const { return mPrototype; }
......@@ -698,7 +745,10 @@ class TIntermDeclaration : public TIntermNode, public TIntermAggregateBase
~TIntermDeclaration() {}
TIntermDeclaration *getAsDeclarationNode() override { return this; }
void traverse(TIntermTraverser *it) override;
bool visit(Visit visit, TIntermTraverser *it) final;
unsigned int getChildCount() final;
TIntermNode *getChildNode(unsigned int index) final;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
// Only intended for initially building the declaration.
......@@ -719,10 +769,12 @@ class TIntermInvariantDeclaration : public TIntermNode
TIntermInvariantDeclaration(TIntermSymbol *symbol, const TSourceLoc &line);
virtual TIntermInvariantDeclaration *getAsInvariantDeclarationNode() override { return this; }
bool visit(Visit visit, TIntermTraverser *it) final;
TIntermSymbol *getSymbol() { return mSymbol; }
void traverse(TIntermTraverser *it) override;
unsigned int getChildCount() final;
TIntermNode *getChildNode(unsigned int index) final;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
private:
......@@ -735,13 +787,16 @@ class TIntermTernary : public TIntermExpression
public:
TIntermTernary(TIntermTyped *cond, TIntermTyped *trueExpression, TIntermTyped *falseExpression);
void traverse(TIntermTraverser *it) override;
TIntermTernary *getAsTernaryNode() override { return this; }
bool visit(Visit visit, TIntermTraverser *it) final;
unsigned int getChildCount() final;
TIntermNode *getChildNode(unsigned int index) final;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
TIntermTyped *getCondition() const { return mCondition; }
TIntermTyped *getTrueExpression() const { return mTrueExpression; }
TIntermTyped *getFalseExpression() const { return mFalseExpression; }
TIntermTernary *getAsTernaryNode() override { return this; }
TIntermTyped *deepCopy() const override { return new TIntermTernary(*this); }
......@@ -770,13 +825,16 @@ class TIntermIfElse : public TIntermNode
public:
TIntermIfElse(TIntermTyped *cond, TIntermBlock *trueB, TIntermBlock *falseB);
void traverse(TIntermTraverser *it) override;
TIntermIfElse *getAsIfElseNode() override { return this; }
bool visit(Visit visit, TIntermTraverser *it) final;
unsigned int getChildCount() final;
TIntermNode *getChildNode(unsigned int index) final;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
TIntermTyped *getCondition() const { return mCondition; }
TIntermBlock *getTrueBlock() const { return mTrueBlock; }
TIntermBlock *getFalseBlock() const { return mFalseBlock; }
TIntermIfElse *getAsIfElseNode() override { return this; }
protected:
TIntermTyped *mCondition;
......@@ -792,10 +850,12 @@ class TIntermSwitch : public TIntermNode
public:
TIntermSwitch(TIntermTyped *init, TIntermBlock *statementList);
void traverse(TIntermTraverser *it) override;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
TIntermSwitch *getAsSwitchNode() override { return this; }
bool visit(Visit visit, TIntermTraverser *it) final;
unsigned int getChildCount() final;
TIntermNode *getChildNode(unsigned int index) final;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
TIntermTyped *getInit() { return mInit; }
TIntermBlock *getStatementList() { return mStatementList; }
......@@ -816,10 +876,12 @@ class TIntermCase : public TIntermNode
public:
TIntermCase(TIntermTyped *condition) : TIntermNode(), mCondition(condition) {}
void traverse(TIntermTraverser *it) override;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
TIntermCase *getAsCaseNode() override { return this; }
bool visit(Visit visit, TIntermTraverser *it) final;
unsigned int getChildCount() final;
TIntermNode *getChildNode(unsigned int index) final;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
bool hasCondition() const { return mCondition != nullptr; }
TIntermTyped *getCondition() const { return mCondition; }
......
......@@ -13,19 +13,64 @@
namespace sh
{
// Traverse the intermediate representation tree, and call a node type specific visit function for
// each node. Traversal is done recursively through the node member function traverse(). Nodes with
// children can have their whole subtree skipped if preVisit is turned on and the type specific
// function returns false.
template <typename T>
void TIntermTraverser::traverse(T *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true;
// Visit the node before children if pre-visiting.
if (preVisit)
visit = node->visit(PreVisit, this);
if (visit)
{
unsigned int childIndex = 0;
unsigned int childCount = node->getChildCount();
while (childIndex < childCount && visit)
{
node->getChildNode(childIndex)->traverse(this);
if (inVisit && childIndex != childCount - 1)
{
visit = node->visit(InVisit, this);
}
++childIndex;
}
if (visit && postVisit)
node->visit(PostVisit, this);
}
}
void TIntermNode::traverse(TIntermTraverser *it)
{
it->traverse(this);
}
void TIntermSymbol::traverse(TIntermTraverser *it)
{
it->traverseSymbol(this);
TIntermTraverser::ScopedNodeInTraversalPath addToPath(it, this);
it->visitSymbol(this);
}
void TIntermConstantUnion::traverse(TIntermTraverser *it)
{
it->traverseConstantUnion(this);
TIntermTraverser::ScopedNodeInTraversalPath addToPath(it, this);
it->visitConstantUnion(this);
}
void TIntermSwizzle::traverse(TIntermTraverser *it)
void TIntermFunctionPrototype::traverse(TIntermTraverser *it)
{
it->traverseSwizzle(this);
TIntermTraverser::ScopedNodeInTraversalPath addToPath(it, this);
it->visitFunctionPrototype(this);
}
void TIntermBinary::traverse(TIntermTraverser *it)
......@@ -38,64 +83,112 @@ void TIntermUnary::traverse(TIntermTraverser *it)
it->traverseUnary(this);
}
void TIntermTernary::traverse(TIntermTraverser *it)
void TIntermFunctionDefinition::traverse(TIntermTraverser *it)
{
it->traverseTernary(this);
it->traverseFunctionDefinition(this);
}
void TIntermIfElse::traverse(TIntermTraverser *it)
void TIntermBlock::traverse(TIntermTraverser *it)
{
it->traverseIfElse(this);
it->traverseBlock(this);
}
void TIntermSwitch::traverse(TIntermTraverser *it)
void TIntermAggregate::traverse(TIntermTraverser *it)
{
it->traverseSwitch(this);
it->traverseAggregate(this);
}
void TIntermCase::traverse(TIntermTraverser *it)
void TIntermLoop::traverse(TIntermTraverser *it)
{
it->traverseCase(this);
it->traverseLoop(this);
}
void TIntermFunctionDefinition::traverse(TIntermTraverser *it)
bool TIntermSymbol::visit(Visit visit, TIntermTraverser *it)
{
it->traverseFunctionDefinition(this);
it->visitSymbol(this);
return false;
}
void TIntermBlock::traverse(TIntermTraverser *it)
bool TIntermConstantUnion::visit(Visit visit, TIntermTraverser *it)
{
it->traverseBlock(this);
it->visitConstantUnion(this);
return false;
}
void TIntermInvariantDeclaration::traverse(TIntermTraverser *it)
bool TIntermFunctionPrototype::visit(Visit visit, TIntermTraverser *it)
{
it->traverseInvariantDeclaration(this);
it->visitFunctionPrototype(this);
return false;
}
void TIntermDeclaration::traverse(TIntermTraverser *it)
bool TIntermFunctionDefinition::visit(Visit visit, TIntermTraverser *it)
{
it->traverseDeclaration(this);
return it->visitFunctionDefinition(visit, this);
}
void TIntermFunctionPrototype::traverse(TIntermTraverser *it)
bool TIntermUnary::visit(Visit visit, TIntermTraverser *it)
{
it->traverseFunctionPrototype(this);
return it->visitUnary(visit, this);
}
void TIntermAggregate::traverse(TIntermTraverser *it)
bool TIntermSwizzle::visit(Visit visit, TIntermTraverser *it)
{
it->traverseAggregate(this);
return it->visitSwizzle(visit, this);
}
void TIntermLoop::traverse(TIntermTraverser *it)
bool TIntermBinary::visit(Visit visit, TIntermTraverser *it)
{
it->traverseLoop(this);
return it->visitBinary(visit, this);
}
void TIntermBranch::traverse(TIntermTraverser *it)
bool TIntermTernary::visit(Visit visit, TIntermTraverser *it)
{
it->traverseBranch(this);
return it->visitTernary(visit, this);
}
bool TIntermAggregate::visit(Visit visit, TIntermTraverser *it)
{
return it->visitAggregate(visit, this);
}
bool TIntermDeclaration::visit(Visit visit, TIntermTraverser *it)
{
return it->visitDeclaration(visit, this);
}
bool TIntermInvariantDeclaration::visit(Visit visit, TIntermTraverser *it)
{
return it->visitInvariantDeclaration(visit, this);
}
bool TIntermBlock::visit(Visit visit, TIntermTraverser *it)
{
return it->visitBlock(visit, this);
}
bool TIntermIfElse::visit(Visit visit, TIntermTraverser *it)
{
return it->visitIfElse(visit, this);
}
bool TIntermLoop::visit(Visit visit, TIntermTraverser *it)
{
return it->visitLoop(visit, this);
}
bool TIntermBranch::visit(Visit visit, TIntermTraverser *it)
{
return it->visitBranch(visit, this);
}
bool TIntermSwitch::visit(Visit visit, TIntermTraverser *it)
{
return it->visitSwitch(visit, this);
}
bool TIntermCase::visit(Visit visit, TIntermTraverser *it)
{
return it->visitCase(visit, this);
}
TIntermTraverser::TIntermTraverser(bool preVisit,
......@@ -110,6 +203,8 @@ TIntermTraverser::TIntermTraverser(bool preVisit,
mInGlobalScope(true),
mSymbolTable(symbolTable)
{
// Only enabling inVisit is not supported.
ASSERT(!(inVisit && !preVisit && !postVisit));
}
TIntermTraverser::~TIntermTraverser()
......@@ -186,89 +281,9 @@ bool TLValueTrackingTraverser::isInFunctionCallOutParameter() const
return mInFunctionCallOutParameter;
}
//
// Traverse the intermediate representation tree, and
// call a node type specific function for each node.
// Done recursively through the member function Traverse().
// Node types can be skipped if their function to call is 0,
// but their subtree will still be traversed.
// Nodes with children can have their whole subtree skipped
// if preVisit is turned on and the type specific function
// returns false.
//
//
// Traversal functions for terminals are straighforward....
//
void TIntermTraverser::traverseSymbol(TIntermSymbol *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
visitSymbol(node);
}
void TIntermTraverser::traverseConstantUnion(TIntermConstantUnion *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
visitConstantUnion(node);
}
void TIntermTraverser::traverseSwizzle(TIntermSwizzle *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true;
if (preVisit)
visit = visitSwizzle(PreVisit, node);
if (visit)
{
node->getOperand()->traverse(this);
}
if (visit && postVisit)
visitSwizzle(PostVisit, node);
}
//
// Traverse a binary node.
//
void TIntermTraverser::traverseBinary(TIntermBinary *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true;
//
// visit the node before children if pre-visiting.
//
if (preVisit)
visit = visitBinary(PreVisit, node);
//
// Visit the children, in the right order.
//
if (visit)
{
node->getLeft()->traverse(this);
if (inVisit)
visit = visitBinary(InVisit, node);
if (visit)
node->getRight()->traverse(this);
}
//
// Visit the node after the children, if requested and the traversal
// hasn't been cancelled yet.
//
if (visit && postVisit)
visitBinary(PostVisit, node);
traverse(node);
}
void TLValueTrackingTraverser::traverseBinary(TIntermBinary *node)
......@@ -279,21 +294,13 @@ void TLValueTrackingTraverser::traverseBinary(TIntermBinary *node)
bool visit = true;
//
// visit the node before children if pre-visiting.
//
if (preVisit)
visit = visitBinary(PreVisit, node);
visit = node->visit(PreVisit, this);
//
// Visit the children, in the right order.
//
if (visit)
{
// Some binary operations like indexing can be inside an expression which must be an
// l-value.
bool parentOperatorRequiresLValue = operatorRequiresLValue();
bool parentInFunctionCallOutParameter = isInFunctionCallOutParameter();
if (node->isAssignment())
{
ASSERT(!isLValueRequiredHere());
......@@ -302,58 +309,45 @@ void TLValueTrackingTraverser::traverseBinary(TIntermBinary *node)
node->getLeft()->traverse(this);
if (inVisit)
visit = visitBinary(InVisit, node);
if (node->isAssignment())
setOperatorRequiresLValue(false);
// Index is not required to be an l-value even when the surrounding expression is required
// to be an l-value.
TOperator op = node->getOp();
if (op == EOpIndexDirect || op == EOpIndexDirectInterfaceBlock ||
op == EOpIndexDirectStruct || op == EOpIndexIndirect)
{
setOperatorRequiresLValue(false);
setInFunctionCallOutParameter(false);
}
if (inVisit)
visit = node->visit(InVisit, this);
if (visit)
{
// Some binary operations like indexing can be inside an expression which must be an
// l-value.
bool parentOperatorRequiresLValue = operatorRequiresLValue();
bool parentInFunctionCallOutParameter = isInFunctionCallOutParameter();
// Index is not required to be an l-value even when the surrounding expression is
// required to be an l-value.
TOperator op = node->getOp();
if (op == EOpIndexDirect || op == EOpIndexDirectInterfaceBlock ||
op == EOpIndexDirectStruct || op == EOpIndexIndirect)
{
setOperatorRequiresLValue(false);
setInFunctionCallOutParameter(false);
}
node->getRight()->traverse(this);
setOperatorRequiresLValue(parentOperatorRequiresLValue);
setInFunctionCallOutParameter(parentInFunctionCallOutParameter);
}
setOperatorRequiresLValue(parentOperatorRequiresLValue);
setInFunctionCallOutParameter(parentInFunctionCallOutParameter);
//
// Visit the node after the children, if requested and the traversal
// hasn't been cancelled yet.
//
if (visit && postVisit)
visitBinary(PostVisit, node);
// Visit the node after the children, if requested and the traversal
// hasn't been cancelled yet.
if (postVisit)
visit = node->visit(PostVisit, this);
}
}
}
//
// Traverse a unary node. Same comments in binary node apply here.
//
void TIntermTraverser::traverseUnary(TIntermUnary *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true;
if (preVisit)
visit = visitUnary(PreVisit, node);
if (visit)
{
node->getOperand()->traverse(this);
}
if (visit && postVisit)
visitUnary(PostVisit, node);
traverse(node);
}
void TLValueTrackingTraverser::traverseUnary(TIntermUnary *node)
......@@ -365,7 +359,7 @@ void TLValueTrackingTraverser::traverseUnary(TIntermUnary *node)
bool visit = true;
if (preVisit)
visit = visitUnary(PreVisit, node);
visit = node->visit(PreVisit, this);
if (visit)
{
......@@ -385,13 +379,13 @@ void TLValueTrackingTraverser::traverseUnary(TIntermUnary *node)
node->getOperand()->traverse(this);
setOperatorRequiresLValue(false);
}
if (visit && postVisit)
visitUnary(PostVisit, node);
if (postVisit)
visit = node->visit(PostVisit, this);
}
}
// Traverse a function definition node.
// Traverse a function definition node. This keeps track of global scope.
void TIntermTraverser::traverseFunctionDefinition(TIntermFunctionDefinition *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
......@@ -401,25 +395,26 @@ void TIntermTraverser::traverseFunctionDefinition(TIntermFunctionDefinition *nod
bool visit = true;
if (preVisit)
visit = visitFunctionDefinition(PreVisit, node);
visit = node->visit(PreVisit, this);
if (visit)
{
mInGlobalScope = false;
node->getFunctionPrototype()->traverse(this);
if (inVisit)
visit = visitFunctionDefinition(InVisit, node);
node->getBody()->traverse(this);
mInGlobalScope = true;
visit = node->visit(InVisit, this);
if (visit)
{
mInGlobalScope = false;
node->getBody()->traverse(this);
mInGlobalScope = true;
if (postVisit)
visit = node->visit(PostVisit, this);
}
}
if (visit && postVisit)
visitFunctionDefinition(PostVisit, node);
}
// Traverse a block node.
// Traverse a block node. This keeps track of the position of traversed child nodes within the block
// so that nodes may be inserted before or after them.
void TIntermTraverser::traverseBlock(TIntermBlock *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
......@@ -433,119 +428,35 @@ void TIntermTraverser::traverseBlock(TIntermBlock *node)
TIntermSequence *sequence = node->getSequence();
if (preVisit)
visit = visitBlock(PreVisit, node);
visit = node->visit(PreVisit, this);
if (visit)
{
for (auto *child : *sequence)
{
child->traverse(this);
if (visit && inVisit)
if (visit)
{
if (child != sequence->back())
visit = visitBlock(InVisit, node);
child->traverse(this);
if (inVisit)
{
if (child != sequence->back())
visit = node->visit(InVisit, this);
}
incrementParentBlockPos();
}
incrementParentBlockPos();
}
}
if (visit && postVisit)
visitBlock(PostVisit, node);
popParentBlock();
}
void TIntermTraverser::traverseInvariantDeclaration(TIntermInvariantDeclaration *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true;
if (preVisit)
{
visit = visitInvariantDeclaration(PreVisit, node);
}
if (visit)
{
node->getSymbol()->traverse(this);
if (postVisit)
{
visitInvariantDeclaration(PostVisit, node);
}
}
}
// Traverse a declaration node.
void TIntermTraverser::traverseDeclaration(TIntermDeclaration *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true;
TIntermSequence *sequence = node->getSequence();
if (preVisit)
visit = visitDeclaration(PreVisit, node);
if (visit)
{
for (auto *child : *sequence)
{
child->traverse(this);
if (visit && inVisit)
{
if (child != sequence->back())
visit = visitDeclaration(InVisit, node);
}
}
if (visit && postVisit)
visit = node->visit(PostVisit, this);
}
if (visit && postVisit)
visitDeclaration(PostVisit, node);
}
void TIntermTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
visitFunctionPrototype(node);
popParentBlock();
}
// Traverse an aggregate node. Same comments in binary node apply here.
void TIntermTraverser::traverseAggregate(TIntermAggregate *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true;
TIntermSequence *sequence = node->getSequence();
if (preVisit)
visit = visitAggregate(PreVisit, node);
if (visit)
{
for (auto *child : *sequence)
{
child->traverse(this);
if (visit && inVisit)
{
if (child != sequence->back())
visit = visitAggregate(InVisit, node);
}
}
}
if (visit && postVisit)
visitAggregate(PostVisit, node);
traverse(node);
}
bool TIntermTraverser::CompareInsertion(const NodeInsertMultipleEntry &a,
......@@ -661,195 +572,47 @@ void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
TIntermSequence *sequence = node->getSequence();
if (preVisit)
visit = visitAggregate(PreVisit, node);
visit = node->visit(PreVisit, this);
if (visit)
{
size_t paramIndex = 0u;
for (auto *child : *sequence)
{
if (node->getFunction())
if (visit)
{
// Both built-ins and user defined functions should have the function symbol set.
ASSERT(paramIndex < node->getFunction()->getParamCount());
TQualifier qualifier =
node->getFunction()->getParam(paramIndex)->getType().getQualifier();
setInFunctionCallOutParameter(qualifier == EvqOut || qualifier == EvqInOut);
++paramIndex;
}
else
{
ASSERT(node->isConstructor());
}
child->traverse(this);
if (visit && inVisit)
{
if (child != sequence->back())
visit = visitAggregate(InVisit, node);
if (node->getFunction())
{
// Both built-ins and user defined functions should have the function symbol
// set.
ASSERT(paramIndex < node->getFunction()->getParamCount());
TQualifier qualifier =
node->getFunction()->getParam(paramIndex)->getType().getQualifier();
setInFunctionCallOutParameter(qualifier == EvqOut || qualifier == EvqInOut);
++paramIndex;
}
else
{
ASSERT(node->isConstructor());
}
child->traverse(this);
if (inVisit)
{
if (child != sequence->back())
visit = node->visit(InVisit, this);
}
}
}
setInFunctionCallOutParameter(false);
}
if (visit && postVisit)
visitAggregate(PostVisit, node);
}
//
// Traverse a ternary node. Same comments in binary node apply here.
//
void TIntermTraverser::traverseTernary(TIntermTernary *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true;
if (preVisit)
visit = visitTernary(PreVisit, node);
if (visit)
{
node->getCondition()->traverse(this);
node->getTrueExpression()->traverse(this);
node->getFalseExpression()->traverse(this);
if (visit && postVisit)
visit = node->visit(PostVisit, this);
}
if (visit && postVisit)
visitTernary(PostVisit, node);
}
// Traverse an if-else node. Same comments in binary node apply here.
void TIntermTraverser::traverseIfElse(TIntermIfElse *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true;
if (preVisit)
visit = visitIfElse(PreVisit, node);
if (visit)
{
node->getCondition()->traverse(this);
if (node->getTrueBlock())
node->getTrueBlock()->traverse(this);
if (node->getFalseBlock())
node->getFalseBlock()->traverse(this);
}
if (visit && postVisit)
visitIfElse(PostVisit, node);
}
//
// Traverse a switch node. Same comments in binary node apply here.
//
void TIntermTraverser::traverseSwitch(TIntermSwitch *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true;
if (preVisit)
visit = visitSwitch(PreVisit, node);
if (visit)
{
node->getInit()->traverse(this);
if (inVisit)
visit = visitSwitch(InVisit, node);
if (visit && node->getStatementList())
node->getStatementList()->traverse(this);
}
if (visit && postVisit)
visitSwitch(PostVisit, node);
}
//
// Traverse a case node. Same comments in binary node apply here.
//
void TIntermTraverser::traverseCase(TIntermCase *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true;
if (preVisit)
visit = visitCase(PreVisit, node);
if (visit && node->getCondition())
{
node->getCondition()->traverse(this);
}
if (visit && postVisit)
visitCase(PostVisit, node);
}
//
// Traverse a loop node. Same comments in binary node apply here.
//
void TIntermTraverser::traverseLoop(TIntermLoop *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true;
if (preVisit)
visit = visitLoop(PreVisit, node);
if (visit)
{
if (node->getInit())
node->getInit()->traverse(this);
if (node->getCondition())
node->getCondition()->traverse(this);
if (node->getBody())
node->getBody()->traverse(this);
if (node->getExpression())
node->getExpression()->traverse(this);
}
if (visit && postVisit)
visitLoop(PostVisit, node);
}
//
// Traverse a branch node. Same comments in binary node apply here.
//
void TIntermTraverser::traverseBranch(TIntermBranch *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true;
if (preVisit)
visit = visitBranch(PreVisit, node);
if (visit && node->getExpression())
{
node->getExpression()->traverse(this);
}
if (visit && postVisit)
visitBranch(PostVisit, node);
traverse(node);
}
} // namespace sh
......@@ -10,6 +10,7 @@
#define COMPILER_TRANSLATOR_TREEUTIL_INTERMTRAVERSE_H_
#include "compiler/translator/IntermNode.h"
#include "compiler/translator/tree_util/Visit.h"
namespace sh
{
......@@ -17,13 +18,6 @@ namespace sh
class TSymbolTable;
class TSymbolUniqueId;
enum Visit
{
PreVisit,
InVisit,
PostVisit
};
// For traversing the tree. User should derive from this class overriding the visit functions,
// and then pass an object of the subclass to a traverse method of a node.
//
......@@ -73,23 +67,21 @@ class TIntermTraverser : angle::NonCopyable
// The traverse functions contain logic for iterating over the children of the node
// and calling the visit functions in the appropriate places. They also track some
// context that may be used by the visit functions.
virtual void traverseSymbol(TIntermSymbol *node);
virtual void traverseConstantUnion(TIntermConstantUnion *node);
virtual void traverseSwizzle(TIntermSwizzle *node);
// The generic traverse() function is used for nodes that don't need special handling.
// It's templated in order to avoid virtual function calls, this gains around 2% compiler
// performance.
template <typename T>
void traverse(T *node);
// Specialized traverse functions are implemented for node types where traversal logic may need
// to be overridden or where some special bookkeeping needs to be done.
virtual void traverseBinary(TIntermBinary *node);
virtual void traverseUnary(TIntermUnary *node);
virtual void traverseTernary(TIntermTernary *node);
virtual void traverseIfElse(TIntermIfElse *node);
virtual void traverseSwitch(TIntermSwitch *node);
virtual void traverseCase(TIntermCase *node);
virtual void traverseFunctionPrototype(TIntermFunctionPrototype *node);
virtual void traverseFunctionDefinition(TIntermFunctionDefinition *node);
virtual void traverseAggregate(TIntermAggregate *node);
virtual void traverseBlock(TIntermBlock *node);
virtual void traverseInvariantDeclaration(TIntermInvariantDeclaration *node);
virtual void traverseDeclaration(TIntermDeclaration *node);
virtual void traverseLoop(TIntermLoop *node);
virtual void traverseBranch(TIntermBranch *node);
int getMaxDepth() const { return mMaxDepth; }
......@@ -134,6 +126,10 @@ class TIntermTraverser : angle::NonCopyable
TIntermTraverser *mTraverser;
bool mWithinDepthLimit;
};
// Optimized traversal functions for leaf nodes directly access ScopedNodeInTraversalPath.
friend void TIntermSymbol::traverse(TIntermTraverser *);
friend void TIntermConstantUnion::traverse(TIntermTraverser *);
friend void TIntermFunctionPrototype::traverse(TIntermTraverser *);
TIntermNode *getParentNode() { return mPath.size() <= 1 ? nullptr : mPath[mPath.size() - 2u]; }
......
//
// Copyright (c) 2018 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
#ifndef COMPILER_TRANSLATOR_TREEUTIL_VISIT_H_
#define COMPILER_TRANSLATOR_TREEUTIL_VISIT_H_
namespace sh
{
enum Visit
{
PreVisit,
InVisit,
PostVisit
};
} // namespace sh
#endif // COMPILER_TRANSLATOR_TREEUTIL_VISIT_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment