Commit d4f303ee by Olli Etuaho

Refactoring: Make creating temporary symbols in AST traversal reusable

Temporary symbols will also be needed to store temporary arrays when complex array expressions are unfolded. Also clear tree update related structures at the end of updateTree(), so that the traverser can be reused for several rounds of replacement more easily, and remove unnecessary InVisit step from UnfoldShortCircuitToIf. BUG=angleproject:971 TEST=angle_end2end_tests, WebGL conformance tests Change-Id: Iecdd3008d43f01b02fe344ccde8614f70e6c0c65 Reviewed-on: https://chromium-review.googlesource.com/272121Reviewed-by: 's avatarZhenyao Mo <zmo@chromium.org> Tested-by: 's avatarOlli Etuaho <oetuaho@nvidia.com>
parent bf790420
...@@ -1879,4 +1879,8 @@ void TIntermTraverser::updateTree() ...@@ -1879,4 +1879,8 @@ void TIntermTraverser::updateTree()
ASSERT(replaced); ASSERT(replaced);
UNUSED_ASSERTION_VARIABLE(replaced); UNUSED_ASSERTION_VARIABLE(replaced);
} }
mInsertions.clear();
mReplacements.clear();
mMultiReplacements.clear();
} }
...@@ -594,7 +594,10 @@ class TIntermTraverser : angle::NonCopyable ...@@ -594,7 +594,10 @@ class TIntermTraverser : angle::NonCopyable
postVisit(postVisit), postVisit(postVisit),
rightToLeft(rightToLeft), rightToLeft(rightToLeft),
mDepth(0), mDepth(0),
mMaxDepth(0) {} mMaxDepth(0),
mTemporaryIndex(nullptr)
{
}
virtual ~TIntermTraverser() {} virtual ~TIntermTraverser() {}
virtual void visitSymbol(TIntermSymbol *) {} virtual void visitSymbol(TIntermSymbol *) {}
...@@ -647,6 +650,9 @@ class TIntermTraverser : angle::NonCopyable ...@@ -647,6 +650,9 @@ class TIntermTraverser : angle::NonCopyable
// this function after traversal to perform them. // this function after traversal to perform them.
void updateTree(); void updateTree();
// Start creating temporary symbols from the given temporary symbol index + 1.
void useTemporaryIndex(unsigned int *temporaryIndex);
protected: protected:
int mDepth; int mDepth;
int mMaxDepth; int mMaxDepth;
...@@ -716,6 +722,15 @@ class TIntermTraverser : angle::NonCopyable ...@@ -716,6 +722,15 @@ class TIntermTraverser : angle::NonCopyable
// supported. // supported.
void insertStatementsInParentBlock(const TIntermSequence &insertions); void insertStatementsInParentBlock(const TIntermSequence &insertions);
// Helper to create a temporary symbol node.
TIntermSymbol *createTempSymbol(const TType &type);
// Create a node that initializes the current temporary symbol with initializer.
TIntermAggregate *createTempInitDeclaration(TIntermTyped *initializer);
// Create a node that assigns rightNode to the current temporary symbol.
TIntermBinary *createTempAssignment(TIntermTyped *rightNode);
// Increment temporary symbol index.
void nextTemporaryIndex();
private: private:
struct ParentBlock struct ParentBlock
{ {
...@@ -730,6 +745,8 @@ class TIntermTraverser : angle::NonCopyable ...@@ -730,6 +745,8 @@ class TIntermTraverser : angle::NonCopyable
}; };
// All the code blocks from the root to the current node's parent during traversal. // All the code blocks from the root to the current node's parent during traversal.
std::vector<ParentBlock> mParentBlockStack; std::vector<ParentBlock> mParentBlockStack;
unsigned int *mTemporaryIndex;
}; };
// //
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
// //
#include "compiler/translator/IntermNode.h" #include "compiler/translator/IntermNode.h"
#include "compiler/translator/InfoSink.h"
void TIntermTraverser::pushParentBlock(TIntermAggregate *node) void TIntermTraverser::pushParentBlock(TIntermAggregate *node)
{ {
...@@ -35,6 +36,56 @@ void TIntermTraverser::insertStatementsInParentBlock(const TIntermSequence &inse ...@@ -35,6 +36,56 @@ void TIntermTraverser::insertStatementsInParentBlock(const TIntermSequence &inse
mInsertions.push_back(insert); mInsertions.push_back(insert);
} }
TIntermSymbol *TIntermTraverser::createTempSymbol(const TType &type)
{
// Each traversal uses at most one temporary variable, so the index stays the same within a single traversal.
TInfoSinkBase symbolNameOut;
ASSERT(mTemporaryIndex != nullptr);
symbolNameOut << "s" << (*mTemporaryIndex);
TString symbolName = symbolNameOut.c_str();
TIntermSymbol *node = new TIntermSymbol(0, symbolName, type);
node->setInternal(true);
node->getTypePointer()->setQualifier(EvqTemporary);
return node;
}
TIntermAggregate *TIntermTraverser::createTempInitDeclaration(TIntermTyped *initializer)
{
ASSERT(initializer != nullptr);
TIntermSymbol *tempSymbol = createTempSymbol(initializer->getType());
TIntermAggregate *tempDeclaration = new TIntermAggregate(EOpDeclaration);
TIntermBinary *tempInit = new TIntermBinary(EOpInitialize);
tempInit->setLeft(tempSymbol);
tempInit->setRight(initializer);
tempInit->setType(tempSymbol->getType());
tempDeclaration->getSequence()->push_back(tempInit);
return tempDeclaration;
}
TIntermBinary *TIntermTraverser::createTempAssignment(TIntermTyped *rightNode)
{
ASSERT(rightNode != nullptr);
TIntermSymbol *tempSymbol = createTempSymbol(rightNode->getType());
TIntermBinary *assignment = new TIntermBinary(EOpAssign);
assignment->setLeft(tempSymbol);
assignment->setRight(rightNode);
assignment->setType(tempSymbol->getType());
return assignment;
}
void TIntermTraverser::useTemporaryIndex(unsigned int *temporaryIndex)
{
mTemporaryIndex = temporaryIndex;
}
void TIntermTraverser::nextTemporaryIndex()
{
ASSERT(mTemporaryIndex != nullptr);
++(*mTemporaryIndex);
}
// //
// Traverse the intermediate representation tree, and // Traverse the intermediate representation tree, and
// call a node type specific function for each node. // call a node type specific function for each node.
......
...@@ -25,8 +25,10 @@ void TranslatorHLSL::translate(TIntermNode *root, int compileOptions) ...@@ -25,8 +25,10 @@ void TranslatorHLSL::translate(TIntermNode *root, int compileOptions)
SeparateDeclarations(root); SeparateDeclarations(root);
unsigned int temporaryIndex = 0;
// Note that SeparateDeclarations needs to be run before UnfoldShortCircuitToIf. // Note that SeparateDeclarations needs to be run before UnfoldShortCircuitToIf.
UnfoldShortCircuitToIf(root); UnfoldShortCircuitToIf(root, &temporaryIndex);
// Note that SeparateDeclarations needs to be run before SeparateArrayInitialization. // Note that SeparateDeclarations needs to be run before SeparateArrayInitialization.
SeparateArrayInitialization(root); SeparateArrayInitialization(root);
......
...@@ -10,7 +10,6 @@ ...@@ -10,7 +10,6 @@
#include "compiler/translator/UnfoldShortCircuitToIf.h" #include "compiler/translator/UnfoldShortCircuitToIf.h"
#include "compiler/translator/InfoSink.h"
#include "compiler/translator/IntermNode.h" #include "compiler/translator/IntermNode.h"
namespace namespace
...@@ -31,60 +30,17 @@ class UnfoldShortCircuitTraverser : public TIntermTraverser ...@@ -31,60 +30,17 @@ class UnfoldShortCircuitTraverser : public TIntermTraverser
bool foundShortCircuit() const { return mFoundShortCircuit; } bool foundShortCircuit() const { return mFoundShortCircuit; }
protected: protected:
int mTemporaryIndex;
// Marked to true once an operation that needs to be unfolded has been found. // Marked to true once an operation that needs to be unfolded has been found.
// After that, no more unfolding is performed on that traversal. // After that, no more unfolding is performed on that traversal.
bool mFoundShortCircuit; bool mFoundShortCircuit;
TIntermSymbol *createTempSymbol(const TType &type);
TIntermAggregate *createTempInitDeclaration(const TType &type, TIntermTyped *initializer);
TIntermBinary *createTempAssignment(const TType &type, TIntermTyped *rightNode);
}; };
UnfoldShortCircuitTraverser::UnfoldShortCircuitTraverser() UnfoldShortCircuitTraverser::UnfoldShortCircuitTraverser()
: TIntermTraverser(true, true, true), : TIntermTraverser(true, false, true),
mTemporaryIndex(0),
mFoundShortCircuit(false) mFoundShortCircuit(false)
{ {
} }
TIntermSymbol *UnfoldShortCircuitTraverser::createTempSymbol(const TType &type)
{
// Each traversal uses at most one temporary variable, so the index stays the same within a single traversal.
TInfoSinkBase symbolNameOut;
symbolNameOut << "s" << mTemporaryIndex;
TString symbolName = symbolNameOut.c_str();
TIntermSymbol *node = new TIntermSymbol(0, symbolName, type);
node->setInternal(true);
return node;
}
TIntermAggregate *UnfoldShortCircuitTraverser::createTempInitDeclaration(const TType &type, TIntermTyped *initializer)
{
ASSERT(initializer != nullptr);
TIntermSymbol *tempSymbol = createTempSymbol(type);
TIntermAggregate *tempDeclaration = new TIntermAggregate(EOpDeclaration);
TIntermBinary *tempInit = new TIntermBinary(EOpInitialize);
tempInit->setLeft(tempSymbol);
tempInit->setRight(initializer);
tempInit->setType(type);
tempDeclaration->getSequence()->push_back(tempInit);
return tempDeclaration;
}
TIntermBinary *UnfoldShortCircuitTraverser::createTempAssignment(const TType &type, TIntermTyped *rightNode)
{
ASSERT(rightNode != nullptr);
TIntermSymbol *tempSymbol = createTempSymbol(type);
TIntermBinary *assignment = new TIntermBinary(EOpAssign);
assignment->setLeft(tempSymbol);
assignment->setRight(rightNode);
assignment->setType(type);
return assignment;
}
bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node) bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node)
{ {
if (mFoundShortCircuit) if (mFoundShortCircuit)
...@@ -108,10 +64,12 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -108,10 +64,12 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node)
TIntermSequence insertions; TIntermSequence insertions;
TType boolType(EbtBool, EbpUndefined, EvqTemporary); TType boolType(EbtBool, EbpUndefined, EvqTemporary);
insertions.push_back(createTempInitDeclaration(boolType, node->getLeft())); ASSERT(node->getLeft()->getType() == boolType);
insertions.push_back(createTempInitDeclaration(node->getLeft()));
TIntermAggregate *assignRightBlock = new TIntermAggregate(EOpSequence); TIntermAggregate *assignRightBlock = new TIntermAggregate(EOpSequence);
assignRightBlock->getSequence()->push_back(createTempAssignment(boolType, node->getRight())); ASSERT(node->getRight()->getType() == boolType);
assignRightBlock->getSequence()->push_back(createTempAssignment(node->getRight()));
TIntermUnary *notTempSymbol = new TIntermUnary(EOpLogicalNot, boolType); TIntermUnary *notTempSymbol = new TIntermUnary(EOpLogicalNot, boolType);
notTempSymbol->setOperand(createTempSymbol(boolType)); notTempSymbol->setOperand(createTempSymbol(boolType));
...@@ -133,10 +91,12 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -133,10 +91,12 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node)
TIntermSequence insertions; TIntermSequence insertions;
TType boolType(EbtBool, EbpUndefined, EvqTemporary); TType boolType(EbtBool, EbpUndefined, EvqTemporary);
insertions.push_back(createTempInitDeclaration(boolType, node->getLeft())); ASSERT(node->getLeft()->getType() == boolType);
insertions.push_back(createTempInitDeclaration(node->getLeft()));
TIntermAggregate *assignRightBlock = new TIntermAggregate(EOpSequence); TIntermAggregate *assignRightBlock = new TIntermAggregate(EOpSequence);
assignRightBlock->getSequence()->push_back(createTempAssignment(boolType, node->getRight())); ASSERT(node->getRight()->getType() == boolType);
assignRightBlock->getSequence()->push_back(createTempAssignment(node->getRight()));
TIntermSelection *ifNode = new TIntermSelection(createTempSymbol(boolType), assignRightBlock, nullptr); TIntermSelection *ifNode = new TIntermSelection(createTempSymbol(boolType), assignRightBlock, nullptr);
insertions.push_back(ifNode); insertions.push_back(ifNode);
...@@ -169,11 +129,11 @@ bool UnfoldShortCircuitTraverser::visitSelection(Visit visit, TIntermSelection * ...@@ -169,11 +129,11 @@ bool UnfoldShortCircuitTraverser::visitSelection(Visit visit, TIntermSelection *
insertions.push_back(tempDeclaration); insertions.push_back(tempDeclaration);
TIntermAggregate *trueBlock = new TIntermAggregate(EOpSequence); TIntermAggregate *trueBlock = new TIntermAggregate(EOpSequence);
TIntermBinary *trueAssignment = createTempAssignment(node->getType(), node->getTrueBlock()->getAsTyped()); TIntermBinary *trueAssignment = createTempAssignment(node->getTrueBlock()->getAsTyped());
trueBlock->getSequence()->push_back(trueAssignment); trueBlock->getSequence()->push_back(trueAssignment);
TIntermAggregate *falseBlock = new TIntermAggregate(EOpSequence); TIntermAggregate *falseBlock = new TIntermAggregate(EOpSequence);
TIntermBinary *falseAssignment = createTempAssignment(node->getType(), node->getFalseBlock()->getAsTyped()); TIntermBinary *falseAssignment = createTempAssignment(node->getFalseBlock()->getAsTyped());
falseBlock->getSequence()->push_back(falseAssignment); falseBlock->getSequence()->push_back(falseAssignment);
TIntermSelection *ifNode = new TIntermSelection(node->getCondition()->getAsTyped(), trueBlock, falseBlock); TIntermSelection *ifNode = new TIntermSelection(node->getCondition()->getAsTyped(), trueBlock, falseBlock);
...@@ -235,17 +195,16 @@ bool UnfoldShortCircuitTraverser::visitAggregate(Visit visit, TIntermAggregate * ...@@ -235,17 +195,16 @@ bool UnfoldShortCircuitTraverser::visitAggregate(Visit visit, TIntermAggregate *
void UnfoldShortCircuitTraverser::nextIteration() void UnfoldShortCircuitTraverser::nextIteration()
{ {
mFoundShortCircuit = false; mFoundShortCircuit = false;
mTemporaryIndex++; nextTemporaryIndex();
mReplacements.clear();
mMultiReplacements.clear();
mInsertions.clear();
} }
} // namespace } // namespace
void UnfoldShortCircuitToIf(TIntermNode *root) void UnfoldShortCircuitToIf(TIntermNode *root, unsigned int *temporaryIndex)
{ {
UnfoldShortCircuitTraverser traverser; UnfoldShortCircuitTraverser traverser;
ASSERT(temporaryIndex != nullptr);
traverser.useTemporaryIndex(temporaryIndex);
// Unfold one operator at a time, and reset the traverser between iterations. // Unfold one operator at a time, and reset the traverser between iterations.
do do
{ {
......
...@@ -13,6 +13,6 @@ ...@@ -13,6 +13,6 @@
class TIntermNode; class TIntermNode;
void UnfoldShortCircuitToIf(TIntermNode *root); void UnfoldShortCircuitToIf(TIntermNode *root, unsigned int *temporaryIndex);
#endif // COMPILER_TRANSLATOR_UNFOLDSHORTCIRCUIT_H_ #endif // COMPILER_TRANSLATOR_UNFOLDSHORTCIRCUIT_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment