Commit a26ad58d by Olli Etuaho

Track where l-values are required in AST traversal

This functionality is refactored out of EmulatePrecision to be a common feature of TIntermTraverser. This is done since tracking where l-values are required will be useful for other traversers. For example, it will be needed for converting dynamic indexing of matrices and vectors to function calls. This change adds some overhead to all tree traversers, but the overhead is expected to be small for typical shaders which don't contain too many user-defined functions. BUG=angleproject:1116 TEST=angle_unittests Change-Id: I54d34c2b5093ef028f2b24d854c11c0195dc1dbb Reviewed-on: https://chromium-review.googlesource.com/290514Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org> Tested-by: 's avatarOlli Etuaho <oetuaho@nvidia.com> Reviewed-by: 's avatarZhenyao Mo <zmo@chromium.org>
parent 45808a17
...@@ -292,16 +292,12 @@ bool parentUsesResult(TIntermNode* parent, TIntermNode* node) ...@@ -292,16 +292,12 @@ bool parentUsesResult(TIntermNode* parent, TIntermNode* node)
} // namespace anonymous } // namespace anonymous
EmulatePrecision::EmulatePrecision() EmulatePrecision::EmulatePrecision()
: TIntermTraverser(true, true, true), : TIntermTraverser(true, true, true), mDeclaringVariables(false)
mDeclaringVariables(false),
mInLValue(false),
mInFunctionCallOutParameter(false)
{} {}
void EmulatePrecision::visitSymbol(TIntermSymbol *node) void EmulatePrecision::visitSymbol(TIntermSymbol *node)
{ {
if (canRoundFloat(node->getType()) && if (canRoundFloat(node->getType()) && !mDeclaringVariables && !isLValueRequiredHere())
!mDeclaringVariables && !mInLValue && !mInFunctionCallOutParameter)
{ {
TIntermNode *parent = getParentNode(); TIntermNode *parent = getParentNode();
TIntermNode *replacement = createRoundingFunctionCallNode(node); TIntermNode *replacement = createRoundingFunctionCallNode(node);
...@@ -314,14 +310,6 @@ bool EmulatePrecision::visitBinary(Visit visit, TIntermBinary *node) ...@@ -314,14 +310,6 @@ bool EmulatePrecision::visitBinary(Visit visit, TIntermBinary *node)
{ {
bool visitChildren = true; bool visitChildren = true;
if (node->isAssignment())
{
if (visit == PreVisit)
mInLValue = true;
else if (visit == InVisit)
mInLValue = false;
}
TOperator op = node->getOp(); TOperator op = node->getOp();
// RHS of initialize is not being declared. // RHS of initialize is not being declared.
...@@ -415,22 +403,9 @@ bool EmulatePrecision::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -415,22 +403,9 @@ bool EmulatePrecision::visitAggregate(Visit visit, TIntermAggregate *node)
{ {
case EOpSequence: case EOpSequence:
case EOpConstructStruct: case EOpConstructStruct:
// No special handling
break;
case EOpFunction: case EOpFunction:
if (visit == PreVisit)
{
const TIntermSequence &sequence = *(node->getSequence());
TIntermSequence::const_iterator seqIter = sequence.begin();
TIntermAggregate *params = (*seqIter)->getAsAggregate();
ASSERT(params != NULL);
ASSERT(params->getOp() == EOpParameters);
mFunctionMap[node->getName()] = params->getSequence();
}
break; break;
case EOpPrototype: case EOpPrototype:
if (visit == PreVisit)
mFunctionMap[node->getName()] = node->getSequence();
visitChildren = false; visitChildren = false;
break; break;
case EOpParameters: case EOpParameters:
...@@ -457,50 +432,17 @@ bool EmulatePrecision::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -457,50 +432,17 @@ bool EmulatePrecision::visitAggregate(Visit visit, TIntermAggregate *node)
case EOpFunctionCall: case EOpFunctionCall:
{ {
// Function call. // Function call.
bool inFunctionMap = (mFunctionMap.find(node->getName()) != mFunctionMap.end());
if (visit == PreVisit) if (visit == PreVisit)
{ {
// User-defined function return values are not rounded, this relies on that // User-defined function return values are not rounded, this relies on that
// calculations producing the value were rounded. // calculations producing the value were rounded.
TIntermNode *parent = getParentNode(); TIntermNode *parent = getParentNode();
if (canRoundFloat(node->getType()) && !inFunctionMap && parentUsesResult(parent, node)) if (canRoundFloat(node->getType()) && !isInFunctionMap(node) &&
parentUsesResult(parent, node))
{ {
TIntermNode *replacement = createRoundingFunctionCallNode(node); TIntermNode *replacement = createRoundingFunctionCallNode(node);
mReplacements.push_back(NodeUpdateEntry(parent, node, replacement, true)); mReplacements.push_back(NodeUpdateEntry(parent, node, replacement, true));
} }
if (inFunctionMap)
{
mSeqIterStack.push_back(mFunctionMap[node->getName()]->begin());
if (mSeqIterStack.back() != mFunctionMap[node->getName()]->end())
{
TQualifier qualifier = (*mSeqIterStack.back())->getAsTyped()->getQualifier();
mInFunctionCallOutParameter = (qualifier == EvqOut || qualifier == EvqInOut);
}
}
else
{
// The function is not user-defined - it is likely built-in texture function.
// Assume that those do not have out parameters.
mInFunctionCallOutParameter = false;
}
}
else if (visit == InVisit)
{
if (inFunctionMap)
{
++mSeqIterStack.back();
TQualifier qualifier = (*mSeqIterStack.back())->getAsTyped()->getQualifier();
mInFunctionCallOutParameter = (qualifier == EvqOut || qualifier == EvqInOut);
}
}
else
{
if (inFunctionMap)
{
mSeqIterStack.pop_back();
mInFunctionCallOutParameter = false;
}
} }
break; break;
} }
...@@ -523,15 +465,10 @@ bool EmulatePrecision::visitUnary(Visit visit, TIntermUnary *node) ...@@ -523,15 +465,10 @@ bool EmulatePrecision::visitUnary(Visit visit, TIntermUnary *node)
case EOpNegative: case EOpNegative:
case EOpVectorLogicalNot: case EOpVectorLogicalNot:
case EOpLogicalNot: case EOpLogicalNot:
break;
case EOpPostIncrement: case EOpPostIncrement:
case EOpPostDecrement: case EOpPostDecrement:
case EOpPreIncrement: case EOpPreIncrement:
case EOpPreDecrement: case EOpPreDecrement:
if (visit == PreVisit)
mInLValue = true;
else if (visit == PostVisit)
mInLValue = false;
break; break;
default: default:
if (canRoundFloat(node->getType()) && visit == PreVisit) if (canRoundFloat(node->getType()) && visit == PreVisit)
......
...@@ -56,20 +56,7 @@ class EmulatePrecision : public TIntermTraverser ...@@ -56,20 +56,7 @@ class EmulatePrecision : public TIntermTraverser
EmulationSet mEmulateCompoundMul; EmulationSet mEmulateCompoundMul;
EmulationSet mEmulateCompoundDiv; EmulationSet mEmulateCompoundDiv;
// Stack of function call parameter iterators
std::vector<TIntermSequence::const_iterator> mSeqIterStack;
bool mDeclaringVariables; bool mDeclaringVariables;
bool mInLValue;
bool mInFunctionCallOutParameter;
struct TStringComparator
{
bool operator() (const TString& a, const TString& b) const { return a.compare(b) < 0; }
};
// Map from function names to their parameter sequences
std::map<TString, TIntermSequence*, TStringComparator> mFunctionMap;
}; };
#endif // COMPILER_TRANSLATOR_EMULATE_PRECISION_H_ #endif // COMPILER_TRANSLATOR_EMULATE_PRECISION_H_
...@@ -609,7 +609,9 @@ class TIntermTraverser : angle::NonCopyable ...@@ -609,7 +609,9 @@ class TIntermTraverser : angle::NonCopyable
postVisit(postVisit), postVisit(postVisit),
mDepth(0), mDepth(0),
mMaxDepth(0), mMaxDepth(0),
mTemporaryIndex(nullptr) mTemporaryIndex(nullptr),
mOperatorRequiresLValue(false),
mInFunctionCallOutParameter(false)
{ {
} }
virtual ~TIntermTraverser() {} virtual ~TIntermTraverser() {}
...@@ -671,6 +673,35 @@ class TIntermTraverser : angle::NonCopyable ...@@ -671,6 +673,35 @@ class TIntermTraverser : angle::NonCopyable
// Start creating temporary symbols from the given temporary symbol index + 1. // Start creating temporary symbols from the given temporary symbol index + 1.
void useTemporaryIndex(unsigned int *temporaryIndex); void useTemporaryIndex(unsigned int *temporaryIndex);
// Track whether an l-value is required in the node that is currently being traversed.
// These functions are intended to be called only from traversal functions, not from subclasses
// of TIntermTraverser.
// Use isLValueRequiredHere instead to check all conditions which require an l-value.
void setOperatorRequiresLValue(bool lValueRequired)
{
mOperatorRequiresLValue = lValueRequired;
}
bool operatorRequiresLValue() const { return mOperatorRequiresLValue; }
// Add a function encountered during traversal to the function map. Intended to be called only
// from traversal functions, not from subclasses of TIntermTraverser.
void addToFunctionMap(const TString &name, TIntermSequence *paramSequence);
// Return true if the prototype or definition of the function being called has been encountered
// during traversal.
bool isInFunctionMap(const TIntermAggregate *callNode) const;
// Return the parameters sequence from the function definition or prototype.
TIntermSequence *getFunctionParameters(const TIntermAggregate *callNode);
// Track whether an l-value is required inside a function call.
// This function is intended to be called only from traversal functions, not from traverers.
void setInFunctionCallOutParameter(bool inOutParameter);
bool isLValueRequiredHere() const
{
return mOperatorRequiresLValue || mInFunctionCallOutParameter;
}
protected: protected:
int mDepth; int mDepth;
int mMaxDepth; int mMaxDepth;
...@@ -771,6 +802,17 @@ class TIntermTraverser : angle::NonCopyable ...@@ -771,6 +802,17 @@ class TIntermTraverser : angle::NonCopyable
std::vector<ParentBlock> mParentBlockStack; std::vector<ParentBlock> mParentBlockStack;
unsigned int *mTemporaryIndex; unsigned int *mTemporaryIndex;
bool mOperatorRequiresLValue;
bool mInFunctionCallOutParameter;
struct TStringComparator
{
bool operator()(const TString &a, const TString &b) const { return a.compare(b) < 0; }
};
// Map from mangled function names to their parameter sequences
TMap<TString, TIntermSequence *, TStringComparator> mFunctionMap;
}; };
// //
......
...@@ -96,6 +96,28 @@ void TIntermTraverser::nextTemporaryIndex() ...@@ -96,6 +96,28 @@ void TIntermTraverser::nextTemporaryIndex()
++(*mTemporaryIndex); ++(*mTemporaryIndex);
} }
void TIntermTraverser::addToFunctionMap(const TString &name, TIntermSequence *paramSequence)
{
mFunctionMap[name] = paramSequence;
}
bool TIntermTraverser::isInFunctionMap(const TIntermAggregate *callNode) const
{
ASSERT(callNode->getOp() == EOpFunctionCall || callNode->getOp() == EOpInternalFunctionCall);
return (mFunctionMap.find(callNode->getName()) != mFunctionMap.end());
}
TIntermSequence *TIntermTraverser::getFunctionParameters(const TIntermAggregate *callNode)
{
ASSERT(isInFunctionMap(callNode));
return mFunctionMap[callNode->getName()];
}
void TIntermTraverser::setInFunctionCallOutParameter(bool inOutParameter)
{
mInFunctionCallOutParameter = inOutParameter;
}
// //
// Traverse the intermediate representation tree, and // Traverse the intermediate representation tree, and
// call a node type specific function for each node. // call a node type specific function for each node.
...@@ -140,12 +162,23 @@ void TIntermBinary::traverse(TIntermTraverser *it) ...@@ -140,12 +162,23 @@ void TIntermBinary::traverse(TIntermTraverser *it)
{ {
it->incrementDepth(this); it->incrementDepth(this);
if (isAssignment())
{
// Some binary operations like indexing can be inside an l-value.
// TODO(oetuaho@nvidia.com): Now the code doesn't unset operatorRequiresLValue for the
// index, fix this.
it->setOperatorRequiresLValue(true);
}
if (mLeft) if (mLeft)
mLeft->traverse(it); mLeft->traverse(it);
if (it->inVisit) if (it->inVisit)
visit = it->visitBinary(InVisit, this); visit = it->visitBinary(InVisit, this);
if (isAssignment())
it->setOperatorRequiresLValue(false);
if (visit && mRight) if (visit && mRight)
mRight->traverse(it); mRight->traverse(it);
...@@ -170,9 +203,26 @@ void TIntermUnary::traverse(TIntermTraverser *it) ...@@ -170,9 +203,26 @@ void TIntermUnary::traverse(TIntermTraverser *it)
if (it->preVisit) if (it->preVisit)
visit = it->visitUnary(PreVisit, this); visit = it->visitUnary(PreVisit, this);
if (visit) { if (visit)
{
it->incrementDepth(this); it->incrementDepth(this);
switch (getOp())
{
case EOpPostIncrement:
case EOpPostDecrement:
case EOpPreIncrement:
case EOpPreDecrement:
it->setOperatorRequiresLValue(true);
break;
default:
break;
}
mOperand->traverse(it); mOperand->traverse(it);
it->setOperatorRequiresLValue(false);
it->decrementDepth(); it->decrementDepth();
} }
...@@ -187,36 +237,87 @@ void TIntermAggregate::traverse(TIntermTraverser *it) ...@@ -187,36 +237,87 @@ void TIntermAggregate::traverse(TIntermTraverser *it)
{ {
bool visit = true; bool visit = true;
switch (mOp)
{
case EOpFunction:
{
TIntermAggregate *params = mSequence.front()->getAsAggregate();
ASSERT(params != nullptr);
ASSERT(params->getOp() == EOpParameters);
it->addToFunctionMap(mName, params->getSequence());
break;
}
case EOpPrototype:
it->addToFunctionMap(mName, &mSequence);
break;
default:
break;
}
if (it->preVisit) if (it->preVisit)
visit = it->visitAggregate(PreVisit, this); visit = it->visitAggregate(PreVisit, this);
if (visit) if (visit)
{ {
if (mOp == EOpSequence) bool inFunctionMap = false;
it->pushParentBlock(this); if (mOp == EOpFunctionCall)
{
inFunctionMap = it->isInFunctionMap(this);
if (!inFunctionMap)
{
// The function is not user-defined - it is likely built-in texture function.
// Assume that those do not have out parameters.
it->setInFunctionCallOutParameter(false);
}
}
it->incrementDepth(this); it->incrementDepth(this);
for (TIntermSequence::iterator sit = mSequence.begin(); if (inFunctionMap)
sit != mSequence.end(); sit++)
{ {
(*sit)->traverse(it); TIntermSequence *params = it->getFunctionParameters(this);
TIntermSequence::iterator paramIter = params->begin();
if (visit && it->inVisit) for (auto *child : mSequence)
{ {
if (*sit != mSequence.back()) ASSERT(paramIter != params->end());
visit = it->visitAggregate(InVisit, this); TQualifier qualifier = (*paramIter)->getAsTyped()->getQualifier();
it->setInFunctionCallOutParameter(qualifier == EvqOut || qualifier == EvqInOut);
child->traverse(it);
if (visit && it->inVisit)
{
if (child != mSequence.back())
visit = it->visitAggregate(InVisit, this);
}
++paramIter;
} }
it->setInFunctionCallOutParameter(false);
}
else
{
if (mOp == EOpSequence) if (mOp == EOpSequence)
it->pushParentBlock(this);
for (auto *child : mSequence)
{ {
it->incrementParentBlockPos(); child->traverse(it);
if (visit && it->inVisit)
{
if (child != mSequence.back())
visit = it->visitAggregate(InVisit, this);
}
if (mOp == EOpSequence)
it->incrementParentBlockPos();
} }
if (mOp == EOpSequence)
it->popParentBlock();
} }
it->decrementDepth(); it->decrementDepth();
if (mOp == EOpSequence)
it->popParentBlock();
} }
if (visit && it->postVisit) if (visit && it->postVisit)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment