Commit 336b1470 by Olli Etuaho Committed by Commit Bot

Split TIntermFunctionDefinition from TIntermAggregate

This makes the code easier to understand. Function definition nodes always have just two children, the parameters node and the function body node, so there was no proper reason why they should be aggregate nodes. As a part of this change, intermediate output is modified to print symbol table ids of functions so that debugging function id related functionality will be easier in the future. After this patch, TIntermAggregate is still used for function prototypes, function parameter lists, function calls, variable and invariant declarations and the comma (sequence) operator. BUG=angleproject:1490 TEST=angle_unittests, angle_end2end_tests Change-Id: Ib88b4ca5d21abd5f126836ca5900d0baecabd19e Reviewed-on: https://chromium-review.googlesource.com/394707 Commit-Queue: Olli Etuaho <oetuaho@nvidia.com> Reviewed-by: 's avatarGeoff Lang <geofflang@chromium.org>
parent 138ec92f
......@@ -31,7 +31,7 @@ class PullGradient : public TIntermTraverser
ASSERT(index < metadataList->size());
}
void traverse(TIntermAggregate *node)
void traverse(TIntermFunctionDefinition *node)
{
node->traverse(this);
ASSERT(mParents.empty());
......@@ -158,7 +158,7 @@ class PullComputeDiscontinuousAndGradientLoops : public TIntermTraverser
{
}
void traverse(TIntermAggregate *node)
void traverse(TIntermFunctionDefinition *node)
{
node->traverse(this);
ASSERT(mLoopsAndSwitches.empty());
......@@ -328,7 +328,7 @@ class PushDiscontinuousLoops : public TIntermTraverser
{
}
void traverse(TIntermAggregate *node)
void traverse(TIntermFunctionDefinition *node)
{
node->traverse(this);
ASSERT(mNestedDiscont == (mMetadata->mCalledInDiscontinuousLoop ? 1 : 0));
......
......@@ -31,16 +31,15 @@ class AddDefaultReturnStatementsTraverser : private TIntermTraverser
private:
AddDefaultReturnStatementsTraverser() : TIntermTraverser(true, false, false) {}
static bool IsFunctionWithoutReturnStatement(TIntermAggregate *node, TType *returnType)
static bool IsFunctionWithoutReturnStatement(TIntermFunctionDefinition *node, TType *returnType)
{
*returnType = node->getType();
if (node->getOp() != EOpFunction || node->getType().getBasicType() == EbtVoid)
if (node->getType().getBasicType() == EbtVoid)
{
return false;
}
TIntermBlock *bodyNode = node->getSequence()->back()->getAsBlock();
ASSERT(bodyNode);
TIntermBlock *bodyNode = node->getBody();
TIntermBranch *returnNode = bodyNode->getSequence()->back()->getAsBranchNode();
if (returnNode != nullptr && returnNode->getFlowOp() == EOpReturn)
{
......@@ -50,7 +49,7 @@ class AddDefaultReturnStatementsTraverser : private TIntermTraverser
return true;
}
bool visitAggregate(Visit, TIntermAggregate *node) override
bool visitFunctionDefinition(Visit, TIntermFunctionDefinition *node) override
{
TType returnType;
if (IsFunctionWithoutReturnStatement(node, &returnType))
......@@ -58,7 +57,7 @@ class AddDefaultReturnStatementsTraverser : private TIntermTraverser
TIntermBranch *branch =
new TIntermBranch(EOpReturn, TIntermTyped::CreateZero(returnType));
TIntermBlock *bodyNode = node->getSequence()->back()->getAsBlock();
TIntermBlock *bodyNode = node->getBody();
bodyNode->getSequence()->push_back(branch);
return false;
......
......@@ -60,6 +60,7 @@ class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser
private:
ArrayReturnValueToOutParameterTraverser();
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *node) override;
bool visitBranch(Visit visit, TIntermBranch *node) override;
bool visitBinary(Visit visit, TIntermBinary *node) override;
......@@ -81,35 +82,47 @@ ArrayReturnValueToOutParameterTraverser::ArrayReturnValueToOutParameterTraverser
{
}
bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
bool ArrayReturnValueToOutParameterTraverser::visitFunctionDefinition(
Visit visit,
TIntermFunctionDefinition *node)
{
if (visit == PreVisit)
if (node->isArray() && visit == PreVisit)
{
if (node->isArray())
{
if (node->getOp() == EOpFunction)
{
// Replace the parameters child node of the function definition with another node
// that has the out parameter added.
// Also set the function to return void.
// Replace the parameters child node of the function definition with another node
// that has the out parameter added.
// Also set the function to return void.
TIntermAggregate *params = node->getSequence()->front()->getAsAggregate();
ASSERT(params != nullptr && params->getOp() == EOpParameters);
TIntermAggregate *params = node->getFunctionParameters();
ASSERT(params != nullptr && params->getOp() == EOpParameters);
TIntermAggregate *replacementParams = new TIntermAggregate;
replacementParams->setOp(EOpParameters);
CopyAggregateChildren(params, replacementParams);
replacementParams->getSequence()->push_back(CreateReturnValueOutSymbol(node->getType()));
replacementParams->setLine(params->getLine());
TIntermAggregate *replacementParams = new TIntermAggregate;
replacementParams->setOp(EOpParameters);
CopyAggregateChildren(params, replacementParams);
replacementParams->getSequence()->push_back(CreateReturnValueOutSymbol(node->getType()));
replacementParams->setLine(params->getLine());
queueReplacementWithParent(node, params, replacementParams,
OriginalNode::IS_DROPPED);
queueReplacementWithParent(node, params, replacementParams, OriginalNode::IS_DROPPED);
node->setType(TType(EbtVoid));
node->setType(TType(EbtVoid));
mInFunctionWithArrayReturnValue = true;
}
else if (node->getOp() == EOpPrototype)
mInFunctionWithArrayReturnValue = true;
}
if (visit == PostVisit)
{
// This isn't conditional on node->isArray() since the type has already been changed on
// PreVisit.
mInFunctionWithArrayReturnValue = false;
}
return true;
}
bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
{
if (visit == PreVisit)
{
if (node->isArray())
{
if (node->getOp() == EOpPrototype)
{
// Replace the whole prototype node with another node that has the out parameter added.
TIntermAggregate *replacement = new TIntermAggregate;
......@@ -149,13 +162,6 @@ bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TInter
}
}
}
else if (visit == PostVisit)
{
if (node->getOp() == EOpFunction)
{
mInFunctionWithArrayReturnValue = false;
}
}
return true;
}
......
......@@ -94,13 +94,39 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
}
std::set<CreatorFunctionData*> callees;
TIntermAggregate *node;
TIntermFunctionDefinition *node;
TString name;
size_t index;
bool indexAssigned;
bool visiting;
};
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
{
// Create the record if need be and remember the node.
if (visit == PreVisit)
{
auto it = mFunctions.find(node->getFunctionSymbolInfo()->getName());
if (it == mFunctions.end())
{
mCurrentFunction = &mFunctions[node->getFunctionSymbolInfo()->getName()];
}
else
{
mCurrentFunction = &it->second;
}
mCurrentFunction->node = node;
mCurrentFunction->name = node->getFunctionSymbolInfo()->getName();
}
else if (visit == PostVisit)
{
mCurrentFunction = nullptr;
}
return true;
}
// Aggregates the AST node for each function as well as the name of the functions called by it
bool visitAggregate(Visit visit, TIntermAggregate *node) override
{
......@@ -114,31 +140,6 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
record.name = node->getFunctionSymbolInfo()->getName();
}
break;
case EOpFunction:
{
// Function definition, create the record if need be and remember the node.
if (visit == PreVisit)
{
auto it = mFunctions.find(node->getFunctionSymbolInfo()->getName());
if (it == mFunctions.end())
{
mCurrentFunction = &mFunctions[node->getFunctionSymbolInfo()->getName()];
}
else
{
mCurrentFunction = &it->second;
}
mCurrentFunction->node = node;
mCurrentFunction->name = node->getFunctionSymbolInfo()->getName();
}
else if (visit == PostVisit)
{
mCurrentFunction = nullptr;
}
break;
}
case EOpFunctionCall:
{
// Function call, add the callees
......
......@@ -41,7 +41,7 @@ class CallDAG : angle::NonCopyable
struct Record
{
std::string name;
TIntermAggregate *node;
TIntermFunctionDefinition *node;
std::vector<int> callees;
};
......
......@@ -763,22 +763,31 @@ class TCompiler::UnusedPredicate
bool operator ()(TIntermNode *node)
{
const TIntermAggregate *asAggregate = node->getAsAggregate();
const TIntermFunctionDefinition *asFunction = node->getAsFunctionDefinition();
if (asAggregate == nullptr)
const TFunctionSymbolInfo *functionInfo = nullptr;
if (asFunction)
{
return false;
functionInfo = asFunction->getFunctionSymbolInfo();
}
if (!(asAggregate->getOp() == EOpFunction || asAggregate->getOp() == EOpPrototype))
else if (asAggregate)
{
if (asAggregate->getOp() == EOpPrototype)
{
functionInfo = asAggregate->getFunctionSymbolInfo();
}
}
if (functionInfo == nullptr)
{
return false;
}
size_t callDagIndex = mCallDag->findIndex(asAggregate->getFunctionSymbolInfo());
size_t callDagIndex = mCallDag->findIndex(functionInfo);
if (callDagIndex == CallDAG::InvalidIndex)
{
// This happens only for unimplemented prototypes which are thus unused
ASSERT(asAggregate->getOp() == EOpPrototype);
ASSERT(asAggregate && asAggregate->getOp() == EOpPrototype);
return true;
}
......
......@@ -38,18 +38,16 @@ TIntermAggregate *CreateFunctionPrototypeNode(const char *name, const int functi
return functionNode;
}
TIntermAggregate *CreateFunctionDefinitionNode(const char *name,
TIntermBlock *functionBody,
const int functionId)
TIntermFunctionDefinition *CreateFunctionDefinitionNode(const char *name,
TIntermBlock *functionBody,
const int functionId)
{
TIntermAggregate *functionNode = new TIntermAggregate(EOpFunction);
TType returnType(EbtVoid);
TIntermAggregate *paramsNode = new TIntermAggregate(EOpParameters);
functionNode->getSequence()->push_back(paramsNode);
functionNode->getSequence()->push_back(functionBody);
TIntermFunctionDefinition *functionNode =
new TIntermFunctionDefinition(returnType, paramsNode, functionBody);
SetInternalFunctionName(functionNode->getFunctionSymbolInfo(), name);
TType returnType(EbtVoid);
functionNode->setType(returnType);
functionNode->getFunctionSymbolInfo()->setId(functionId);
return functionNode;
}
......@@ -156,25 +154,22 @@ void DeferGlobalInitializersTraverser::insertInitFunction(TIntermBlock *root)
{
functionBody->push_back(deferredInit);
}
TIntermAggregate *functionDefinition =
TIntermFunctionDefinition *functionDefinition =
CreateFunctionDefinitionNode(functionName, functionBodyNode, initFunctionId);
root->getSequence()->push_back(functionDefinition);
// Insert call into main function
for (TIntermNode *node : *root->getSequence())
{
TIntermAggregate *nodeAgg = node->getAsAggregate();
if (nodeAgg != nullptr && nodeAgg->getOp() == EOpFunction &&
nodeAgg->getFunctionSymbolInfo()->isMain())
TIntermFunctionDefinition *nodeFunction = node->getAsFunctionDefinition();
if (nodeFunction != nullptr && nodeFunction->getFunctionSymbolInfo()->isMain())
{
TIntermAggregate *functionCallNode =
CreateFunctionCallNode(functionName, initFunctionId);
TIntermNode *mainBody = nodeAgg->getSequence()->back();
TIntermBlock *mainBodyBlock = mainBody->getAsBlock();
ASSERT(mainBodyBlock != nullptr);
mainBodyBlock->getSequence()->insert(mainBodyBlock->getSequence()->begin(),
functionCallNode);
TIntermBlock *mainBody = nodeFunction->getBody();
ASSERT(mainBody != nullptr);
mainBody->getSequence()->insert(mainBody->getSequence()->begin(), functionCallNode);
}
}
}
......
......@@ -34,7 +34,7 @@ class GLFragColorBroadcastTraverser : public TIntermTraverser
protected:
void visitSymbol(TIntermSymbol *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *node) override;
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
TIntermBinary *constructGLFragDataNode(int index) const;
TIntermBinary *constructGLFragDataAssignNode(int index) const;
......@@ -74,24 +74,15 @@ void GLFragColorBroadcastTraverser::visitSymbol(TIntermSymbol *node)
}
}
bool GLFragColorBroadcastTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
bool GLFragColorBroadcastTraverser::visitFunctionDefinition(Visit visit,
TIntermFunctionDefinition *node)
{
switch (node->getOp())
ASSERT(visit == PreVisit);
if (node->getFunctionSymbolInfo()->isMain())
{
case EOpFunction:
// Function definition.
ASSERT(visit == PreVisit);
if (node->getFunctionSymbolInfo()->isMain())
{
TIntermSequence *sequence = node->getSequence();
ASSERT(sequence->size() == 2);
TIntermBlock *body = (*sequence)[1]->getAsBlock();
ASSERT(body);
mMainSequence = body->getSequence();
}
break;
default:
break;
TIntermBlock *body = node->getBody();
ASSERT(body);
mMainSequence = body->getSequence();
}
return true;
}
......
......@@ -604,7 +604,6 @@ bool EmulatePrecision::visitAggregate(Visit visit, TIntermAggregate *node)
switch (node->getOp())
{
case EOpConstructStruct:
case EOpFunction:
break;
case EOpPrototype:
visitChildren = false;
......
......@@ -28,8 +28,9 @@ class VariableInitializer : public TIntermTraverser
bool visitIfElse(Visit, TIntermIfElse *node) override { return false; }
bool visitLoop(Visit, TIntermLoop *node) override { return false; }
bool visitBranch(Visit, TIntermBranch *node) override { return false; }
bool visitAggregate(Visit, TIntermAggregate *node) override { return false; }
bool visitAggregate(Visit visit, TIntermAggregate *node) override;
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
private:
void insertInitCode(TIntermSequence *sequence);
......@@ -40,31 +41,17 @@ class VariableInitializer : public TIntermTraverser
// VariableInitializer implementation.
bool VariableInitializer::visitAggregate(Visit visit, TIntermAggregate *node)
bool VariableInitializer::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{
bool visitChildren = !mCodeInserted;
switch (node->getOp())
// Function definition.
ASSERT(visit == PreVisit);
if (node->getFunctionSymbolInfo()->isMain())
{
case EOpFunction:
{
// Function definition.
ASSERT(visit == PreVisit);
if (node->getFunctionSymbolInfo()->isMain())
{
TIntermSequence *sequence = node->getSequence();
ASSERT(sequence->size() == 2);
TIntermBlock *body = (*sequence)[1]->getAsBlock();
ASSERT(body);
insertInitCode(body->getSequence());
mCodeInserted = true;
}
break;
}
default:
visitChildren = false;
break;
TIntermBlock *body = node->getBody();
insertInitCode(body->getSequence());
mCodeInserted = true;
}
return visitChildren;
return false;
}
void VariableInitializer::insertInitCode(TIntermSequence *sequence)
......
......@@ -204,6 +204,13 @@ bool TIntermUnary::replaceChildNode(
return false;
}
bool TIntermFunctionDefinition::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
REPLACE_IF_IS(mParameters, TIntermAggregate, original, replacement);
REPLACE_IF_IS(mBody, TIntermBlock, original, replacement);
return false;
}
bool TIntermAggregate::replaceChildNode(
TIntermNode *original, TIntermNode *replacement)
{
......
......@@ -32,6 +32,7 @@ class TDiagnostics;
class TIntermTraverser;
class TIntermAggregate;
class TIntermBlock;
class TIntermFunctionDefinition;
class TIntermSwizzle;
class TIntermBinary;
class TIntermUnary;
......@@ -94,6 +95,7 @@ class TIntermNode : angle::NonCopyable
virtual void traverse(TIntermTraverser *) = 0;
virtual TIntermTyped *getAsTyped() { return 0; }
virtual TIntermConstantUnion *getAsConstantUnion() { return 0; }
virtual TIntermFunctionDefinition *getAsFunctionDefinition() { return nullptr; }
virtual TIntermAggregate *getAsAggregate() { return 0; }
virtual TIntermBlock *getAsBlock() { return nullptr; }
virtual TIntermSwizzle *getAsSwizzleNode() { return nullptr; }
......@@ -550,6 +552,47 @@ class TFunctionSymbolInfo
int mId;
};
// Node for function definitions.
class TIntermFunctionDefinition : public TIntermTyped
{
public:
// TODO(oetuaho@nvidia.com): See if TFunctionSymbolInfo could be added to constructor
// parameters.
TIntermFunctionDefinition(const TType &type, TIntermAggregate *parameters, TIntermBlock *body)
: TIntermTyped(type), mParameters(parameters), mBody(body)
{
ASSERT(parameters != nullptr);
ASSERT(body != nullptr);
}
TIntermFunctionDefinition *getAsFunctionDefinition() override { return this; }
void traverse(TIntermTraverser *it) override;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
TIntermTyped *deepCopy() const override
{
UNREACHABLE();
return nullptr;
}
bool hasSideEffects() const override
{
UNREACHABLE();
return true;
}
TIntermAggregate *getFunctionParameters() const { return mParameters; }
TIntermBlock *getBody() const { return mBody; }
TFunctionSymbolInfo *getFunctionSymbolInfo() { return &mFunctionInfo; }
const TFunctionSymbolInfo *getFunctionSymbolInfo() const { return &mFunctionInfo; }
private:
TIntermAggregate *mParameters;
TIntermBlock *mBody;
TFunctionSymbolInfo mFunctionInfo;
};
typedef TVector<TIntermNode *> TIntermSequence;
typedef TVector<int> TQualifierList;
......@@ -805,6 +848,10 @@ class TIntermTraverser : angle::NonCopyable
virtual bool visitIfElse(Visit visit, TIntermIfElse *node) { return true; }
virtual bool visitSwitch(Visit visit, TIntermSwitch *node) { return true; }
virtual bool visitCase(Visit visit, TIntermCase *node) { return true; }
virtual bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{
return true;
}
virtual bool visitAggregate(Visit visit, TIntermAggregate *node) { return true; }
virtual bool visitBlock(Visit visit, TIntermBlock *node) { return true; }
virtual bool visitLoop(Visit visit, TIntermLoop *node) { return true; }
......@@ -823,6 +870,7 @@ class TIntermTraverser : angle::NonCopyable
virtual void traverseIfElse(TIntermIfElse *node);
virtual void traverseSwitch(TIntermSwitch *node);
virtual void traverseCase(TIntermCase *node);
virtual void traverseFunctionDefinition(TIntermFunctionDefinition *node);
virtual void traverseAggregate(TIntermAggregate *node);
virtual void traverseBlock(TIntermBlock *node);
virtual void traverseLoop(TIntermLoop *node);
......@@ -1039,6 +1087,7 @@ class TLValueTrackingTraverser : public TIntermTraverser
void traverseBinary(TIntermBinary *node) final;
void traverseUnary(TIntermUnary *node) final;
void traverseFunctionDefinition(TIntermFunctionDefinition *node) final;
void traverseAggregate(TIntermAggregate *node) final;
protected:
......
......@@ -58,6 +58,11 @@ void TIntermCase::traverse(TIntermTraverser *it)
it->traverseCase(this);
}
void TIntermFunctionDefinition::traverse(TIntermTraverser *it)
{
it->traverseFunctionDefinition(this);
}
void TIntermBlock::traverse(TIntermTraverser *it)
{
it->traverseBlock(this);
......@@ -424,6 +429,32 @@ void TLValueTrackingTraverser::traverseUnary(TIntermUnary *node)
visitUnary(PostVisit, node);
}
// Traverse a function definition node.
void TIntermTraverser::traverseFunctionDefinition(TIntermFunctionDefinition *node)
{
bool visit = true;
if (preVisit)
visit = visitFunctionDefinition(PreVisit, node);
if (visit)
{
incrementDepth(node);
mInGlobalScope = false;
node->getFunctionParameters()->traverse(this);
if (inVisit)
visit = visitFunctionDefinition(InVisit, node);
node->getBody()->traverse(this);
mInGlobalScope = true;
decrementDepth();
}
if (visit && postVisit)
visitFunctionDefinition(PostVisit, node);
}
// Traverse a block node.
void TIntermTraverser::traverseBlock(TIntermBlock *node)
{
......@@ -473,9 +504,6 @@ void TIntermTraverser::traverseAggregate(TIntermAggregate *node)
{
incrementDepth(node);
if (node->getOp() == EOpFunction)
mInGlobalScope = false;
for (auto *child : *sequence)
{
child->traverse(this);
......@@ -486,9 +514,6 @@ void TIntermTraverser::traverseAggregate(TIntermAggregate *node)
}
}
if (node->getOp() == EOpFunction)
mInGlobalScope = true;
decrementDepth();
}
......@@ -496,26 +521,24 @@ void TIntermTraverser::traverseAggregate(TIntermAggregate *node)
visitAggregate(PostVisit, node);
}
void TLValueTrackingTraverser::traverseFunctionDefinition(TIntermFunctionDefinition *node)
{
TIntermAggregate *params = node->getFunctionParameters();
ASSERT(params != nullptr);
ASSERT(params->getOp() == EOpParameters);
addToFunctionMap(node->getFunctionSymbolInfo()->getNameObj(), params->getSequence());
TIntermTraverser::traverseFunctionDefinition(node);
}
void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
{
bool visit = true;
TIntermSequence *sequence = node->getSequence();
switch (node->getOp())
if (node->getOp() == EOpPrototype)
{
case EOpFunction:
{
TIntermAggregate *params = sequence->front()->getAsAggregate();
ASSERT(params != nullptr);
ASSERT(params->getOp() == EOpParameters);
addToFunctionMap(node->getFunctionSymbolInfo()->getNameObj(), params->getSequence());
break;
}
case EOpPrototype:
addToFunctionMap(node->getFunctionSymbolInfo()->getNameObj(), sequence);
break;
default:
break;
addToFunctionMap(node->getFunctionSymbolInfo()->getNameObj(), sequence);
}
if (preVisit)
......@@ -561,9 +584,6 @@ void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
}
else
{
if (node->getOp() == EOpFunction)
mInGlobalScope = false;
// Find the built-in function corresponding to this op so that we can determine the
// in/out qualifiers of its parameters.
TFunction *builtInFunc = nullptr;
......@@ -609,9 +629,6 @@ void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
}
setInFunctionCallOutParameter(false);
if (node->getOp() == EOpFunction)
mInGlobalScope = true;
}
decrementDepth();
......
......@@ -14,7 +14,6 @@ enum TOperator
{
EOpNull, // if in a node, should only mean a node is still being built
EOpFunctionCall,
EOpFunction, // For function definition
EOpParameters, // an aggregate listing the parameters to a function
EOpDeclaration,
......
......@@ -22,9 +22,9 @@ TString arrayBrackets(const TType &type)
bool isSingleStatement(TIntermNode *node)
{
if (const TIntermAggregate *aggregate = node->getAsAggregate())
if (node->getAsFunctionDefinition())
{
return (aggregate->getOp() != EOpFunction);
return false;
}
else if (node->getAsBlock())
{
......@@ -786,6 +786,35 @@ bool TOutputGLSLBase::visitBlock(Visit visit, TIntermBlock *node)
return false;
}
bool TOutputGLSLBase::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{
TInfoSinkBase &out = objSink();
ASSERT(visit == PreVisit);
{
const TType &type = node->getType();
writeVariableType(type);
if (type.isArray())
out << arrayBrackets(type);
}
out << " " << hashFunctionNameIfNeeded(node->getFunctionSymbolInfo()->getNameObj());
incrementDepth(node);
// Traverse function parameters.
TIntermAggregate *params = node->getFunctionParameters()->getAsAggregate();
ASSERT(params->getOp() == EOpParameters);
params->traverse(this);
// Traverse function body.
visitCodeBlock(node->getBody());
decrementDepth();
// Fully processed; no need to visit children.
return false;
}
bool TOutputGLSLBase::visitAggregate(Visit visit, TIntermAggregate *node)
{
bool visitChildren = true;
......@@ -811,39 +840,6 @@ bool TOutputGLSLBase::visitAggregate(Visit visit, TIntermAggregate *node)
visitChildren = false;
break;
case EOpFunction: {
// Function definition.
ASSERT(visit == PreVisit);
{
const TType &type = node->getType();
writeVariableType(type);
if (type.isArray())
out << arrayBrackets(type);
}
out << " " << hashFunctionNameIfNeeded(node->getFunctionSymbolInfo()->getNameObj());
incrementDepth(node);
// Function definition node contains two child nodes representing the function parameters
// and the function body.
const TIntermSequence &sequence = *(node->getSequence());
ASSERT(sequence.size() == 2);
// Traverse function parameters.
TIntermAggregate *params = sequence[0]->getAsAggregate();
ASSERT(params != NULL);
ASSERT(params->getOp() == EOpParameters);
params->traverse(this);
// Traverse function body.
TIntermBlock *body = sequence[1]->getAsBlock();
visitCodeBlock(body);
decrementDepth();
// Fully processed; no need to visit children.
visitChildren = false;
break;
}
case EOpFunctionCall:
// Function call.
if (visit == PreVisit)
......
......@@ -49,6 +49,7 @@ class TOutputGLSLBase : public TIntermTraverser
bool visitIfElse(Visit visit, TIntermIfElse *node) override;
bool visitSwitch(Visit visit, TIntermSwitch *node) override;
bool visitCase(Visit visit, TIntermCase *node) override;
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *node) override;
bool visitBlock(Visit visit, TIntermBlock *node) override;
bool visitLoop(Visit visit, TIntermLoop *node) override;
......
......@@ -1449,6 +1449,70 @@ bool OutputHLSL::visitBlock(Visit visit, TIntermBlock *node)
return false;
}
bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{
TInfoSinkBase &out = getInfoSink();
ASSERT(mCurrentFunctionMetadata == nullptr);
size_t index = mCallDag.findIndex(node->getFunctionSymbolInfo());
ASSERT(index != CallDAG::InvalidIndex);
mCurrentFunctionMetadata = &mASTMetadataList[index];
out << TypeString(node->getType()) << " ";
TIntermSequence *parameters = node->getFunctionParameters()->getSequence();
if (node->getFunctionSymbolInfo()->isMain())
{
out << "gl_main(";
}
else
{
out << DecorateFunctionIfNeeded(node->getFunctionSymbolInfo()->getNameObj())
<< DisambiguateFunctionName(parameters) << (mOutputLod0Function ? "Lod0(" : "(");
}
for (unsigned int i = 0; i < parameters->size(); i++)
{
TIntermSymbol *symbol = (*parameters)[i]->getAsSymbolNode();
if (symbol)
{
ensureStructDefined(symbol->getType());
out << argumentString(symbol);
if (i < parameters->size() - 1)
{
out << ", ";
}
}
else
UNREACHABLE();
}
out << ")\n";
mInsideFunction = true;
// The function body node will output braces.
node->getBody()->traverse(this);
mInsideFunction = false;
mCurrentFunctionMetadata = nullptr;
bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
{
ASSERT(!node->getFunctionSymbolInfo()->isMain());
mOutputLod0Function = true;
node->traverse(this);
mOutputLod0Function = false;
}
return false;
}
bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
{
TInfoSinkBase &out = getInfoSink();
......@@ -1581,70 +1645,6 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
case EOpComma:
outputTriplet(out, visit, "(", ", ", ")");
break;
case EOpFunction:
{
ASSERT(mCurrentFunctionMetadata == nullptr);
size_t index = mCallDag.findIndex(node->getFunctionSymbolInfo());
ASSERT(index != CallDAG::InvalidIndex);
mCurrentFunctionMetadata = &mASTMetadataList[index];
out << TypeString(node->getType()) << " ";
TIntermSequence *sequence = node->getSequence();
TIntermSequence *arguments = (*sequence)[0]->getAsAggregate()->getSequence();
if (node->getFunctionSymbolInfo()->isMain())
{
out << "gl_main(";
}
else
{
out << DecorateFunctionIfNeeded(node->getFunctionSymbolInfo()->getNameObj())
<< DisambiguateFunctionName(arguments) << (mOutputLod0Function ? "Lod0(" : "(");
}
for (unsigned int i = 0; i < arguments->size(); i++)
{
TIntermSymbol *symbol = (*arguments)[i]->getAsSymbolNode();
if (symbol)
{
ensureStructDefined(symbol->getType());
out << argumentString(symbol);
if (i < arguments->size() - 1)
{
out << ", ";
}
}
else UNREACHABLE();
}
out << ")\n";
mInsideFunction = true;
ASSERT(sequence->size() == 2);
TIntermNode *body = (*sequence)[1];
// The function body node will output braces.
ASSERT(body->getAsBlock() != nullptr);
body->traverse(this);
mInsideFunction = false;
mCurrentFunctionMetadata = nullptr;
bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
{
ASSERT(!node->getFunctionSymbolInfo()->isMain());
mOutputLod0Function = true;
node->traverse(this);
mOutputLod0Function = false;
}
return false;
}
case EOpFunctionCall:
{
TIntermSequence *arguments = node->getSequence();
......
......@@ -66,6 +66,7 @@ class OutputHLSL : public TIntermTraverser
bool visitIfElse(Visit visit, TIntermIfElse *);
bool visitSwitch(Visit visit, TIntermSwitch *);
bool visitCase(Visit visit, TIntermCase *);
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
bool visitAggregate(Visit visit, TIntermAggregate*);
bool visitBlock(Visit visit, TIntermBlock *node);
bool visitLoop(Visit visit, TIntermLoop*);
......
......@@ -2059,10 +2059,11 @@ TIntermAggregate *TParseContext::addFunctionPrototypeDeclaration(const TFunction
return prototype;
}
TIntermAggregate *TParseContext::addFunctionDefinition(const TFunction &function,
TIntermAggregate *functionPrototype,
TIntermBlock *functionBody,
const TSourceLoc &location)
TIntermFunctionDefinition *TParseContext::addFunctionDefinition(
const TFunction &function,
TIntermAggregate *functionParameters,
TIntermBlock *functionBody,
const TSourceLoc &location)
{
// Check that non-void functions have at least one return statement.
if (mCurrentFunctionType->getBasicType() != EbtVoid && !mFunctionReturnsValue)
......@@ -2070,21 +2071,16 @@ TIntermAggregate *TParseContext::addFunctionDefinition(const TFunction &function
error(location, "function does not return a value:", "", function.getName().c_str());
}
TIntermAggregate *functionNode = new TIntermAggregate(EOpFunction);
functionNode->setLine(location);
ASSERT(functionPrototype != nullptr);
functionNode->getSequence()->push_back(functionPrototype);
if (functionBody == nullptr)
{
functionBody = new TIntermBlock();
functionBody->setLine(location);
}
functionNode->getSequence()->push_back(functionBody);
TIntermFunctionDefinition *functionNode =
new TIntermFunctionDefinition(function.getReturnType(), functionParameters, functionBody);
functionNode->setLine(location);
functionNode->getFunctionSymbolInfo()->setFromFunction(function);
functionNode->setType(function.getReturnType());
symbolTable.pop();
return functionNode;
......
......@@ -266,10 +266,10 @@ class TParseContext : angle::NonCopyable
void parseGlobalLayoutQualifier(const TTypeQualifierBuilder &typeQualifierBuilder);
TIntermAggregate *addFunctionPrototypeDeclaration(const TFunction &parsedFunction,
const TSourceLoc &location);
TIntermAggregate *addFunctionDefinition(const TFunction &function,
TIntermAggregate *functionPrototype,
TIntermBlock *functionBody,
const TSourceLoc &location);
TIntermFunctionDefinition *addFunctionDefinition(const TFunction &function,
TIntermAggregate *functionParameters,
TIntermBlock *functionBody,
const TSourceLoc &location);
void parseFunctionDefinitionHeader(const TSourceLoc &location,
TFunction **function,
TIntermAggregate **aggregateOut);
......
......@@ -172,7 +172,7 @@ TType GetFieldType(const TType &indexedType)
// base[1] = value;
// }
// Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write)
TIntermFunctionDefinition *GetIndexFunctionDefinition(TType type, bool write)
{
ASSERT(!type.isArray());
// Conservatively use highp here, even if the indexed type is not highp. That way the code can't
......@@ -180,8 +180,6 @@ TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write)
// highp values are being indexed in the shader. For HLSL precision doesn't matter, but in
// principle this code could be used with multiple backends.
type.setPrecision(EbpHigh);
TIntermAggregate *indexingFunction = new TIntermAggregate(EOpFunction);
indexingFunction->getFunctionSymbolInfo()->setNameObj(GetIndexFunctionName(type, write));
TType fieldType = GetFieldType(type);
int numCases = 0;
......@@ -193,14 +191,6 @@ TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write)
{
numCases = type.getNominalSize();
}
if (write)
{
indexingFunction->setType(TType(EbtVoid));
}
else
{
indexingFunction->setType(fieldType);
}
TIntermAggregate *paramsNode = new TIntermAggregate(EOpParameters);
TQualifier baseQualifier = EvqInOut;
......@@ -215,7 +205,6 @@ TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write)
TIntermSymbol *valueParam = CreateValueSymbol(fieldType);
paramsNode->getSequence()->push_back(valueParam);
}
indexingFunction->getSequence()->push_back(paramsNode);
TIntermBlock *statementList = new TIntermBlock();
for (int i = 0; i < numCases; ++i)
......@@ -284,8 +273,16 @@ TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write)
bodyNode->getSequence()->push_back(ifNode);
bodyNode->getSequence()->push_back(useLastBlock);
indexingFunction->getSequence()->push_back(bodyNode);
TIntermFunctionDefinition *indexingFunction = nullptr;
if (write)
{
indexingFunction = new TIntermFunctionDefinition(TType(EbtVoid), paramsNode, bodyNode);
}
else
{
indexingFunction = new TIntermFunctionDefinition(fieldType, paramsNode, bodyNode);
}
indexingFunction->getFunctionSymbolInfo()->setNameObj(GetIndexFunctionName(type, write));
return indexingFunction;
}
......
......@@ -25,7 +25,7 @@ class ElseBlockRewriter : public TIntermTraverser
ElseBlockRewriter();
protected:
bool visitAggregate(Visit visit, TIntermAggregate *aggregate) override;
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *aggregate) override;
bool visitBlock(Visit visit, TIntermBlock *block) override;
private:
......@@ -39,13 +39,10 @@ ElseBlockRewriter::ElseBlockRewriter()
mFunctionType(NULL)
{}
bool ElseBlockRewriter::visitAggregate(Visit visit, TIntermAggregate *node)
bool ElseBlockRewriter::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{
if (node->getOp() == EOpFunction)
{
// Store the current function context (see comment below)
mFunctionType = ((visit == PreVisit) ? &node->getType() : nullptr);
}
// Store the current function context (see comment below)
mFunctionType = ((visit == PreVisit) ? &node->getType() : nullptr);
return true;
}
......
......@@ -94,7 +94,7 @@ bool SimplifyLoopConditionsTraverser::visitAggregate(Visit visit, TIntermAggrega
// If we're outside a loop condition, we only need to traverse nodes that may contain loops.
if (!mInsideLoopConditionOrExpression)
return (node->getOp() == EOpFunction);
return false;
mFoundLoopToChange = mConditionsToSimplify.match(node, getParentNode());
return !mFoundLoopToChange;
......
......@@ -28,7 +28,8 @@ class UseUniformBlockMembers : public TIntermTraverser
}
protected:
bool visitAggregate(Visit visit, TIntermAggregate *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *node) override { return !mCodeInserted; }
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
private:
void insertUseCode(TIntermSequence *sequence);
......@@ -38,31 +39,18 @@ class UseUniformBlockMembers : public TIntermTraverser
bool mCodeInserted;
};
bool UseUniformBlockMembers::visitAggregate(Visit visit, TIntermAggregate *node)
bool UseUniformBlockMembers::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{
bool visitChildren = !mCodeInserted;
switch (node->getOp())
ASSERT(visit == PreVisit);
if (node->getFunctionSymbolInfo()->isMain())
{
case EOpFunction:
{
ASSERT(visit == PreVisit);
if (node->getFunctionSymbolInfo()->isMain())
{
TIntermSequence *sequence = node->getSequence();
ASSERT(sequence->size() == 2);
TIntermBlock *body = (*sequence)[1]->getAsBlock();
ASSERT(body);
insertUseCode(body->getSequence());
mCodeInserted = true;
visitChildren = false;
}
break;
}
default:
visitChildren = false;
break;
TIntermBlock *body = node->getBody();
ASSERT(body);
insertUseCode(body->getSequence());
mCodeInserted = true;
return false;
}
return visitChildren;
return !mCodeInserted;
}
void UseUniformBlockMembers::AddFieldUseStatements(const ShaderVariable &var,
......
......@@ -10,11 +10,11 @@
namespace
{
void OutputFunction(TInfoSinkBase &out, const char *str, TIntermAggregate *node)
void OutputFunction(TInfoSinkBase &out, const char *str, TFunctionSymbolInfo *info)
{
const char *internal =
node->getFunctionSymbolInfo()->getNameObj().isInternal() ? " (internal function)" : "";
out << str << internal << ": " << node->getFunctionSymbolInfo()->getNameObj().getString();
const char *internal = info->getNameObj().isInternal() ? " (internal function)" : "";
out << str << internal << ": " << info->getNameObj().getString() << " (symbol id "
<< info->getId() << ")";
}
//
......@@ -48,6 +48,7 @@ class TOutputTraverser : public TIntermTraverser
bool visitUnary(Visit visit, TIntermUnary *) override;
bool visitTernary(Visit visit, TIntermTernary *node) override;
bool visitIfElse(Visit visit, TIntermIfElse *node) override;
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *) override;
bool visitBlock(Visit visit, TIntermBlock *) override;
bool visitLoop(Visit visit, TIntermLoop *) override;
......@@ -372,6 +373,15 @@ bool TOutputTraverser::visitUnary(Visit visit, TIntermUnary *node)
return true;
}
bool TOutputTraverser::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{
TInfoSinkBase &out = sink;
OutputTreeText(out, node, mDepth);
OutputFunction(out, "Function Definition", node->getFunctionSymbolInfo());
out << "\n";
return true;
}
bool TOutputTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
{
TInfoSinkBase &out = sink;
......@@ -389,10 +399,13 @@ bool TOutputTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
switch (node->getOp())
{
case EOpComma: out << "Comma\n"; return true;
case EOpFunction: OutputFunction(out, "Function Definition", node); break;
case EOpFunctionCall: OutputFunction(out, "Function Call", node); break;
case EOpFunctionCall:
OutputFunction(out, "Function Call", node->getFunctionSymbolInfo());
break;
case EOpParameters: out << "Function Parameters: "; break;
case EOpPrototype: OutputFunction(out, "Function Prototype", node); break;
case EOpPrototype:
OutputFunction(out, "Function Prototype", node->getFunctionSymbolInfo());
break;
case EOpConstructFloat: out << "Construct float"; break;
case EOpConstructVec2: out << "Construct vec2"; break;
......
......@@ -39,7 +39,7 @@ class TypeTrackingTest : public testing::Test
const char *shaderStrings[] = { shaderString.c_str() };
bool compilationSuccess = mTranslator->compile(shaderStrings, 1, SH_INTERMEDIATE_TREE);
TInfoSink &infoSink = mTranslator->getInfoSink();
mInfoLog = infoSink.info.c_str();
mInfoLog = RemoveSymbolIdsFromInfoLog(infoSink.info.c_str());
if (!compilationSuccess)
FAIL() << "Shader compilation failed " << mInfoLog;
}
......@@ -55,6 +55,23 @@ class TypeTrackingTest : public testing::Test
}
private:
// Remove symbol ids from info log - the tests don't care about them.
static std::string RemoveSymbolIdsFromInfoLog(const char *infoLog)
{
std::string filteredLog(infoLog);
size_t idPrefixPos = 0u;
do
{
idPrefixPos = filteredLog.find(" (symbol id");
if (idPrefixPos != std::string::npos)
{
size_t idSuffixPos = filteredLog.find(")", idPrefixPos);
filteredLog.erase(idPrefixPos, idSuffixPos - idPrefixPos + 1u);
}
} while (idPrefixPos != std::string::npos);
return filteredLog;
}
TranslatorESSL *mTranslator;
std::string mInfoLog;
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment