Commit 336b1470 by Olli Etuaho Committed by Commit Bot

Split TIntermFunctionDefinition from TIntermAggregate

This makes the code easier to understand. Function definition nodes always have just two children, the parameters node and the function body node, so there was no proper reason why they should be aggregate nodes. As a part of this change, intermediate output is modified to print symbol table ids of functions so that debugging function id related functionality will be easier in the future. After this patch, TIntermAggregate is still used for function prototypes, function parameter lists, function calls, variable and invariant declarations and the comma (sequence) operator. BUG=angleproject:1490 TEST=angle_unittests, angle_end2end_tests Change-Id: Ib88b4ca5d21abd5f126836ca5900d0baecabd19e Reviewed-on: https://chromium-review.googlesource.com/394707 Commit-Queue: Olli Etuaho <oetuaho@nvidia.com> Reviewed-by: 's avatarGeoff Lang <geofflang@chromium.org>
parent 138ec92f
...@@ -31,7 +31,7 @@ class PullGradient : public TIntermTraverser ...@@ -31,7 +31,7 @@ class PullGradient : public TIntermTraverser
ASSERT(index < metadataList->size()); ASSERT(index < metadataList->size());
} }
void traverse(TIntermAggregate *node) void traverse(TIntermFunctionDefinition *node)
{ {
node->traverse(this); node->traverse(this);
ASSERT(mParents.empty()); ASSERT(mParents.empty());
...@@ -158,7 +158,7 @@ class PullComputeDiscontinuousAndGradientLoops : public TIntermTraverser ...@@ -158,7 +158,7 @@ class PullComputeDiscontinuousAndGradientLoops : public TIntermTraverser
{ {
} }
void traverse(TIntermAggregate *node) void traverse(TIntermFunctionDefinition *node)
{ {
node->traverse(this); node->traverse(this);
ASSERT(mLoopsAndSwitches.empty()); ASSERT(mLoopsAndSwitches.empty());
...@@ -328,7 +328,7 @@ class PushDiscontinuousLoops : public TIntermTraverser ...@@ -328,7 +328,7 @@ class PushDiscontinuousLoops : public TIntermTraverser
{ {
} }
void traverse(TIntermAggregate *node) void traverse(TIntermFunctionDefinition *node)
{ {
node->traverse(this); node->traverse(this);
ASSERT(mNestedDiscont == (mMetadata->mCalledInDiscontinuousLoop ? 1 : 0)); ASSERT(mNestedDiscont == (mMetadata->mCalledInDiscontinuousLoop ? 1 : 0));
......
...@@ -31,16 +31,15 @@ class AddDefaultReturnStatementsTraverser : private TIntermTraverser ...@@ -31,16 +31,15 @@ class AddDefaultReturnStatementsTraverser : private TIntermTraverser
private: private:
AddDefaultReturnStatementsTraverser() : TIntermTraverser(true, false, false) {} AddDefaultReturnStatementsTraverser() : TIntermTraverser(true, false, false) {}
static bool IsFunctionWithoutReturnStatement(TIntermAggregate *node, TType *returnType) static bool IsFunctionWithoutReturnStatement(TIntermFunctionDefinition *node, TType *returnType)
{ {
*returnType = node->getType(); *returnType = node->getType();
if (node->getOp() != EOpFunction || node->getType().getBasicType() == EbtVoid) if (node->getType().getBasicType() == EbtVoid)
{ {
return false; return false;
} }
TIntermBlock *bodyNode = node->getSequence()->back()->getAsBlock(); TIntermBlock *bodyNode = node->getBody();
ASSERT(bodyNode);
TIntermBranch *returnNode = bodyNode->getSequence()->back()->getAsBranchNode(); TIntermBranch *returnNode = bodyNode->getSequence()->back()->getAsBranchNode();
if (returnNode != nullptr && returnNode->getFlowOp() == EOpReturn) if (returnNode != nullptr && returnNode->getFlowOp() == EOpReturn)
{ {
...@@ -50,7 +49,7 @@ class AddDefaultReturnStatementsTraverser : private TIntermTraverser ...@@ -50,7 +49,7 @@ class AddDefaultReturnStatementsTraverser : private TIntermTraverser
return true; return true;
} }
bool visitAggregate(Visit, TIntermAggregate *node) override bool visitFunctionDefinition(Visit, TIntermFunctionDefinition *node) override
{ {
TType returnType; TType returnType;
if (IsFunctionWithoutReturnStatement(node, &returnType)) if (IsFunctionWithoutReturnStatement(node, &returnType))
...@@ -58,7 +57,7 @@ class AddDefaultReturnStatementsTraverser : private TIntermTraverser ...@@ -58,7 +57,7 @@ class AddDefaultReturnStatementsTraverser : private TIntermTraverser
TIntermBranch *branch = TIntermBranch *branch =
new TIntermBranch(EOpReturn, TIntermTyped::CreateZero(returnType)); new TIntermBranch(EOpReturn, TIntermTyped::CreateZero(returnType));
TIntermBlock *bodyNode = node->getSequence()->back()->getAsBlock(); TIntermBlock *bodyNode = node->getBody();
bodyNode->getSequence()->push_back(branch); bodyNode->getSequence()->push_back(branch);
return false; return false;
......
...@@ -60,6 +60,7 @@ class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser ...@@ -60,6 +60,7 @@ class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser
private: private:
ArrayReturnValueToOutParameterTraverser(); ArrayReturnValueToOutParameterTraverser();
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *node) override; bool visitAggregate(Visit visit, TIntermAggregate *node) override;
bool visitBranch(Visit visit, TIntermBranch *node) override; bool visitBranch(Visit visit, TIntermBranch *node) override;
bool visitBinary(Visit visit, TIntermBinary *node) override; bool visitBinary(Visit visit, TIntermBinary *node) override;
...@@ -81,35 +82,47 @@ ArrayReturnValueToOutParameterTraverser::ArrayReturnValueToOutParameterTraverser ...@@ -81,35 +82,47 @@ ArrayReturnValueToOutParameterTraverser::ArrayReturnValueToOutParameterTraverser
{ {
} }
bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TIntermAggregate *node) bool ArrayReturnValueToOutParameterTraverser::visitFunctionDefinition(
Visit visit,
TIntermFunctionDefinition *node)
{ {
if (visit == PreVisit) if (node->isArray() && visit == PreVisit)
{ {
if (node->isArray()) // Replace the parameters child node of the function definition with another node
{ // that has the out parameter added.
if (node->getOp() == EOpFunction) // Also set the function to return void.
{
// Replace the parameters child node of the function definition with another node
// that has the out parameter added.
// Also set the function to return void.
TIntermAggregate *params = node->getSequence()->front()->getAsAggregate(); TIntermAggregate *params = node->getFunctionParameters();
ASSERT(params != nullptr && params->getOp() == EOpParameters); ASSERT(params != nullptr && params->getOp() == EOpParameters);
TIntermAggregate *replacementParams = new TIntermAggregate; TIntermAggregate *replacementParams = new TIntermAggregate;
replacementParams->setOp(EOpParameters); replacementParams->setOp(EOpParameters);
CopyAggregateChildren(params, replacementParams); CopyAggregateChildren(params, replacementParams);
replacementParams->getSequence()->push_back(CreateReturnValueOutSymbol(node->getType())); replacementParams->getSequence()->push_back(CreateReturnValueOutSymbol(node->getType()));
replacementParams->setLine(params->getLine()); replacementParams->setLine(params->getLine());
queueReplacementWithParent(node, params, replacementParams, queueReplacementWithParent(node, params, replacementParams, OriginalNode::IS_DROPPED);
OriginalNode::IS_DROPPED);
node->setType(TType(EbtVoid)); node->setType(TType(EbtVoid));
mInFunctionWithArrayReturnValue = true; mInFunctionWithArrayReturnValue = true;
} }
else if (node->getOp() == EOpPrototype) if (visit == PostVisit)
{
// This isn't conditional on node->isArray() since the type has already been changed on
// PreVisit.
mInFunctionWithArrayReturnValue = false;
}
return true;
}
bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
{
if (visit == PreVisit)
{
if (node->isArray())
{
if (node->getOp() == EOpPrototype)
{ {
// Replace the whole prototype node with another node that has the out parameter added. // Replace the whole prototype node with another node that has the out parameter added.
TIntermAggregate *replacement = new TIntermAggregate; TIntermAggregate *replacement = new TIntermAggregate;
...@@ -149,13 +162,6 @@ bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TInter ...@@ -149,13 +162,6 @@ bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TInter
} }
} }
} }
else if (visit == PostVisit)
{
if (node->getOp() == EOpFunction)
{
mInFunctionWithArrayReturnValue = false;
}
}
return true; return true;
} }
......
...@@ -94,13 +94,39 @@ class CallDAG::CallDAGCreator : public TIntermTraverser ...@@ -94,13 +94,39 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
} }
std::set<CreatorFunctionData*> callees; std::set<CreatorFunctionData*> callees;
TIntermAggregate *node; TIntermFunctionDefinition *node;
TString name; TString name;
size_t index; size_t index;
bool indexAssigned; bool indexAssigned;
bool visiting; bool visiting;
}; };
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
{
// Create the record if need be and remember the node.
if (visit == PreVisit)
{
auto it = mFunctions.find(node->getFunctionSymbolInfo()->getName());
if (it == mFunctions.end())
{
mCurrentFunction = &mFunctions[node->getFunctionSymbolInfo()->getName()];
}
else
{
mCurrentFunction = &it->second;
}
mCurrentFunction->node = node;
mCurrentFunction->name = node->getFunctionSymbolInfo()->getName();
}
else if (visit == PostVisit)
{
mCurrentFunction = nullptr;
}
return true;
}
// Aggregates the AST node for each function as well as the name of the functions called by it // Aggregates the AST node for each function as well as the name of the functions called by it
bool visitAggregate(Visit visit, TIntermAggregate *node) override bool visitAggregate(Visit visit, TIntermAggregate *node) override
{ {
...@@ -114,31 +140,6 @@ class CallDAG::CallDAGCreator : public TIntermTraverser ...@@ -114,31 +140,6 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
record.name = node->getFunctionSymbolInfo()->getName(); record.name = node->getFunctionSymbolInfo()->getName();
} }
break; break;
case EOpFunction:
{
// Function definition, create the record if need be and remember the node.
if (visit == PreVisit)
{
auto it = mFunctions.find(node->getFunctionSymbolInfo()->getName());
if (it == mFunctions.end())
{
mCurrentFunction = &mFunctions[node->getFunctionSymbolInfo()->getName()];
}
else
{
mCurrentFunction = &it->second;
}
mCurrentFunction->node = node;
mCurrentFunction->name = node->getFunctionSymbolInfo()->getName();
}
else if (visit == PostVisit)
{
mCurrentFunction = nullptr;
}
break;
}
case EOpFunctionCall: case EOpFunctionCall:
{ {
// Function call, add the callees // Function call, add the callees
......
...@@ -41,7 +41,7 @@ class CallDAG : angle::NonCopyable ...@@ -41,7 +41,7 @@ class CallDAG : angle::NonCopyable
struct Record struct Record
{ {
std::string name; std::string name;
TIntermAggregate *node; TIntermFunctionDefinition *node;
std::vector<int> callees; std::vector<int> callees;
}; };
......
...@@ -763,22 +763,31 @@ class TCompiler::UnusedPredicate ...@@ -763,22 +763,31 @@ class TCompiler::UnusedPredicate
bool operator ()(TIntermNode *node) bool operator ()(TIntermNode *node)
{ {
const TIntermAggregate *asAggregate = node->getAsAggregate(); const TIntermAggregate *asAggregate = node->getAsAggregate();
const TIntermFunctionDefinition *asFunction = node->getAsFunctionDefinition();
if (asAggregate == nullptr) const TFunctionSymbolInfo *functionInfo = nullptr;
if (asFunction)
{ {
return false; functionInfo = asFunction->getFunctionSymbolInfo();
} }
else if (asAggregate)
if (!(asAggregate->getOp() == EOpFunction || asAggregate->getOp() == EOpPrototype)) {
if (asAggregate->getOp() == EOpPrototype)
{
functionInfo = asAggregate->getFunctionSymbolInfo();
}
}
if (functionInfo == nullptr)
{ {
return false; return false;
} }
size_t callDagIndex = mCallDag->findIndex(asAggregate->getFunctionSymbolInfo()); size_t callDagIndex = mCallDag->findIndex(functionInfo);
if (callDagIndex == CallDAG::InvalidIndex) if (callDagIndex == CallDAG::InvalidIndex)
{ {
// This happens only for unimplemented prototypes which are thus unused // This happens only for unimplemented prototypes which are thus unused
ASSERT(asAggregate->getOp() == EOpPrototype); ASSERT(asAggregate && asAggregate->getOp() == EOpPrototype);
return true; return true;
} }
......
...@@ -38,18 +38,16 @@ TIntermAggregate *CreateFunctionPrototypeNode(const char *name, const int functi ...@@ -38,18 +38,16 @@ TIntermAggregate *CreateFunctionPrototypeNode(const char *name, const int functi
return functionNode; return functionNode;
} }
TIntermAggregate *CreateFunctionDefinitionNode(const char *name, TIntermFunctionDefinition *CreateFunctionDefinitionNode(const char *name,
TIntermBlock *functionBody, TIntermBlock *functionBody,
const int functionId) const int functionId)
{ {
TIntermAggregate *functionNode = new TIntermAggregate(EOpFunction); TType returnType(EbtVoid);
TIntermAggregate *paramsNode = new TIntermAggregate(EOpParameters); TIntermAggregate *paramsNode = new TIntermAggregate(EOpParameters);
functionNode->getSequence()->push_back(paramsNode); TIntermFunctionDefinition *functionNode =
functionNode->getSequence()->push_back(functionBody); new TIntermFunctionDefinition(returnType, paramsNode, functionBody);
SetInternalFunctionName(functionNode->getFunctionSymbolInfo(), name); SetInternalFunctionName(functionNode->getFunctionSymbolInfo(), name);
TType returnType(EbtVoid);
functionNode->setType(returnType);
functionNode->getFunctionSymbolInfo()->setId(functionId); functionNode->getFunctionSymbolInfo()->setId(functionId);
return functionNode; return functionNode;
} }
...@@ -156,25 +154,22 @@ void DeferGlobalInitializersTraverser::insertInitFunction(TIntermBlock *root) ...@@ -156,25 +154,22 @@ void DeferGlobalInitializersTraverser::insertInitFunction(TIntermBlock *root)
{ {
functionBody->push_back(deferredInit); functionBody->push_back(deferredInit);
} }
TIntermAggregate *functionDefinition = TIntermFunctionDefinition *functionDefinition =
CreateFunctionDefinitionNode(functionName, functionBodyNode, initFunctionId); CreateFunctionDefinitionNode(functionName, functionBodyNode, initFunctionId);
root->getSequence()->push_back(functionDefinition); root->getSequence()->push_back(functionDefinition);
// Insert call into main function // Insert call into main function
for (TIntermNode *node : *root->getSequence()) for (TIntermNode *node : *root->getSequence())
{ {
TIntermAggregate *nodeAgg = node->getAsAggregate(); TIntermFunctionDefinition *nodeFunction = node->getAsFunctionDefinition();
if (nodeAgg != nullptr && nodeAgg->getOp() == EOpFunction && if (nodeFunction != nullptr && nodeFunction->getFunctionSymbolInfo()->isMain())
nodeAgg->getFunctionSymbolInfo()->isMain())
{ {
TIntermAggregate *functionCallNode = TIntermAggregate *functionCallNode =
CreateFunctionCallNode(functionName, initFunctionId); CreateFunctionCallNode(functionName, initFunctionId);
TIntermNode *mainBody = nodeAgg->getSequence()->back(); TIntermBlock *mainBody = nodeFunction->getBody();
TIntermBlock *mainBodyBlock = mainBody->getAsBlock(); ASSERT(mainBody != nullptr);
ASSERT(mainBodyBlock != nullptr); mainBody->getSequence()->insert(mainBody->getSequence()->begin(), functionCallNode);
mainBodyBlock->getSequence()->insert(mainBodyBlock->getSequence()->begin(),
functionCallNode);
} }
} }
} }
......
...@@ -34,7 +34,7 @@ class GLFragColorBroadcastTraverser : public TIntermTraverser ...@@ -34,7 +34,7 @@ class GLFragColorBroadcastTraverser : public TIntermTraverser
protected: protected:
void visitSymbol(TIntermSymbol *node) override; void visitSymbol(TIntermSymbol *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *node) override; bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
TIntermBinary *constructGLFragDataNode(int index) const; TIntermBinary *constructGLFragDataNode(int index) const;
TIntermBinary *constructGLFragDataAssignNode(int index) const; TIntermBinary *constructGLFragDataAssignNode(int index) const;
...@@ -74,24 +74,15 @@ void GLFragColorBroadcastTraverser::visitSymbol(TIntermSymbol *node) ...@@ -74,24 +74,15 @@ void GLFragColorBroadcastTraverser::visitSymbol(TIntermSymbol *node)
} }
} }
bool GLFragColorBroadcastTraverser::visitAggregate(Visit visit, TIntermAggregate *node) bool GLFragColorBroadcastTraverser::visitFunctionDefinition(Visit visit,
TIntermFunctionDefinition *node)
{ {
switch (node->getOp()) ASSERT(visit == PreVisit);
if (node->getFunctionSymbolInfo()->isMain())
{ {
case EOpFunction: TIntermBlock *body = node->getBody();
// Function definition. ASSERT(body);
ASSERT(visit == PreVisit); mMainSequence = body->getSequence();
if (node->getFunctionSymbolInfo()->isMain())
{
TIntermSequence *sequence = node->getSequence();
ASSERT(sequence->size() == 2);
TIntermBlock *body = (*sequence)[1]->getAsBlock();
ASSERT(body);
mMainSequence = body->getSequence();
}
break;
default:
break;
} }
return true; return true;
} }
......
...@@ -604,7 +604,6 @@ bool EmulatePrecision::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -604,7 +604,6 @@ bool EmulatePrecision::visitAggregate(Visit visit, TIntermAggregate *node)
switch (node->getOp()) switch (node->getOp())
{ {
case EOpConstructStruct: case EOpConstructStruct:
case EOpFunction:
break; break;
case EOpPrototype: case EOpPrototype:
visitChildren = false; visitChildren = false;
......
...@@ -28,8 +28,9 @@ class VariableInitializer : public TIntermTraverser ...@@ -28,8 +28,9 @@ class VariableInitializer : public TIntermTraverser
bool visitIfElse(Visit, TIntermIfElse *node) override { return false; } bool visitIfElse(Visit, TIntermIfElse *node) override { return false; }
bool visitLoop(Visit, TIntermLoop *node) override { return false; } bool visitLoop(Visit, TIntermLoop *node) override { return false; }
bool visitBranch(Visit, TIntermBranch *node) override { return false; } bool visitBranch(Visit, TIntermBranch *node) override { return false; }
bool visitAggregate(Visit, TIntermAggregate *node) override { return false; }
bool visitAggregate(Visit visit, TIntermAggregate *node) override; bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
private: private:
void insertInitCode(TIntermSequence *sequence); void insertInitCode(TIntermSequence *sequence);
...@@ -40,31 +41,17 @@ class VariableInitializer : public TIntermTraverser ...@@ -40,31 +41,17 @@ class VariableInitializer : public TIntermTraverser
// VariableInitializer implementation. // VariableInitializer implementation.
bool VariableInitializer::visitAggregate(Visit visit, TIntermAggregate *node) bool VariableInitializer::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{ {
bool visitChildren = !mCodeInserted; // Function definition.
switch (node->getOp()) ASSERT(visit == PreVisit);
if (node->getFunctionSymbolInfo()->isMain())
{ {
case EOpFunction: TIntermBlock *body = node->getBody();
{ insertInitCode(body->getSequence());
// Function definition. mCodeInserted = true;
ASSERT(visit == PreVisit);
if (node->getFunctionSymbolInfo()->isMain())
{
TIntermSequence *sequence = node->getSequence();
ASSERT(sequence->size() == 2);
TIntermBlock *body = (*sequence)[1]->getAsBlock();
ASSERT(body);
insertInitCode(body->getSequence());
mCodeInserted = true;
}
break;
}
default:
visitChildren = false;
break;
} }
return visitChildren; return false;
} }
void VariableInitializer::insertInitCode(TIntermSequence *sequence) void VariableInitializer::insertInitCode(TIntermSequence *sequence)
......
...@@ -204,6 +204,13 @@ bool TIntermUnary::replaceChildNode( ...@@ -204,6 +204,13 @@ bool TIntermUnary::replaceChildNode(
return false; return false;
} }
bool TIntermFunctionDefinition::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
REPLACE_IF_IS(mParameters, TIntermAggregate, original, replacement);
REPLACE_IF_IS(mBody, TIntermBlock, original, replacement);
return false;
}
bool TIntermAggregate::replaceChildNode( bool TIntermAggregate::replaceChildNode(
TIntermNode *original, TIntermNode *replacement) TIntermNode *original, TIntermNode *replacement)
{ {
......
...@@ -32,6 +32,7 @@ class TDiagnostics; ...@@ -32,6 +32,7 @@ class TDiagnostics;
class TIntermTraverser; class TIntermTraverser;
class TIntermAggregate; class TIntermAggregate;
class TIntermBlock; class TIntermBlock;
class TIntermFunctionDefinition;
class TIntermSwizzle; class TIntermSwizzle;
class TIntermBinary; class TIntermBinary;
class TIntermUnary; class TIntermUnary;
...@@ -94,6 +95,7 @@ class TIntermNode : angle::NonCopyable ...@@ -94,6 +95,7 @@ class TIntermNode : angle::NonCopyable
virtual void traverse(TIntermTraverser *) = 0; virtual void traverse(TIntermTraverser *) = 0;
virtual TIntermTyped *getAsTyped() { return 0; } virtual TIntermTyped *getAsTyped() { return 0; }
virtual TIntermConstantUnion *getAsConstantUnion() { return 0; } virtual TIntermConstantUnion *getAsConstantUnion() { return 0; }
virtual TIntermFunctionDefinition *getAsFunctionDefinition() { return nullptr; }
virtual TIntermAggregate *getAsAggregate() { return 0; } virtual TIntermAggregate *getAsAggregate() { return 0; }
virtual TIntermBlock *getAsBlock() { return nullptr; } virtual TIntermBlock *getAsBlock() { return nullptr; }
virtual TIntermSwizzle *getAsSwizzleNode() { return nullptr; } virtual TIntermSwizzle *getAsSwizzleNode() { return nullptr; }
...@@ -550,6 +552,47 @@ class TFunctionSymbolInfo ...@@ -550,6 +552,47 @@ class TFunctionSymbolInfo
int mId; int mId;
}; };
// Node for function definitions.
class TIntermFunctionDefinition : public TIntermTyped
{
public:
// TODO(oetuaho@nvidia.com): See if TFunctionSymbolInfo could be added to constructor
// parameters.
TIntermFunctionDefinition(const TType &type, TIntermAggregate *parameters, TIntermBlock *body)
: TIntermTyped(type), mParameters(parameters), mBody(body)
{
ASSERT(parameters != nullptr);
ASSERT(body != nullptr);
}
TIntermFunctionDefinition *getAsFunctionDefinition() override { return this; }
void traverse(TIntermTraverser *it) override;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
TIntermTyped *deepCopy() const override
{
UNREACHABLE();
return nullptr;
}
bool hasSideEffects() const override
{
UNREACHABLE();
return true;
}
TIntermAggregate *getFunctionParameters() const { return mParameters; }
TIntermBlock *getBody() const { return mBody; }
TFunctionSymbolInfo *getFunctionSymbolInfo() { return &mFunctionInfo; }
const TFunctionSymbolInfo *getFunctionSymbolInfo() const { return &mFunctionInfo; }
private:
TIntermAggregate *mParameters;
TIntermBlock *mBody;
TFunctionSymbolInfo mFunctionInfo;
};
typedef TVector<TIntermNode *> TIntermSequence; typedef TVector<TIntermNode *> TIntermSequence;
typedef TVector<int> TQualifierList; typedef TVector<int> TQualifierList;
...@@ -805,6 +848,10 @@ class TIntermTraverser : angle::NonCopyable ...@@ -805,6 +848,10 @@ class TIntermTraverser : angle::NonCopyable
virtual bool visitIfElse(Visit visit, TIntermIfElse *node) { return true; } virtual bool visitIfElse(Visit visit, TIntermIfElse *node) { return true; }
virtual bool visitSwitch(Visit visit, TIntermSwitch *node) { return true; } virtual bool visitSwitch(Visit visit, TIntermSwitch *node) { return true; }
virtual bool visitCase(Visit visit, TIntermCase *node) { return true; } virtual bool visitCase(Visit visit, TIntermCase *node) { return true; }
virtual bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{
return true;
}
virtual bool visitAggregate(Visit visit, TIntermAggregate *node) { return true; } virtual bool visitAggregate(Visit visit, TIntermAggregate *node) { return true; }
virtual bool visitBlock(Visit visit, TIntermBlock *node) { return true; } virtual bool visitBlock(Visit visit, TIntermBlock *node) { return true; }
virtual bool visitLoop(Visit visit, TIntermLoop *node) { return true; } virtual bool visitLoop(Visit visit, TIntermLoop *node) { return true; }
...@@ -823,6 +870,7 @@ class TIntermTraverser : angle::NonCopyable ...@@ -823,6 +870,7 @@ class TIntermTraverser : angle::NonCopyable
virtual void traverseIfElse(TIntermIfElse *node); virtual void traverseIfElse(TIntermIfElse *node);
virtual void traverseSwitch(TIntermSwitch *node); virtual void traverseSwitch(TIntermSwitch *node);
virtual void traverseCase(TIntermCase *node); virtual void traverseCase(TIntermCase *node);
virtual void traverseFunctionDefinition(TIntermFunctionDefinition *node);
virtual void traverseAggregate(TIntermAggregate *node); virtual void traverseAggregate(TIntermAggregate *node);
virtual void traverseBlock(TIntermBlock *node); virtual void traverseBlock(TIntermBlock *node);
virtual void traverseLoop(TIntermLoop *node); virtual void traverseLoop(TIntermLoop *node);
...@@ -1039,6 +1087,7 @@ class TLValueTrackingTraverser : public TIntermTraverser ...@@ -1039,6 +1087,7 @@ class TLValueTrackingTraverser : public TIntermTraverser
void traverseBinary(TIntermBinary *node) final; void traverseBinary(TIntermBinary *node) final;
void traverseUnary(TIntermUnary *node) final; void traverseUnary(TIntermUnary *node) final;
void traverseFunctionDefinition(TIntermFunctionDefinition *node) final;
void traverseAggregate(TIntermAggregate *node) final; void traverseAggregate(TIntermAggregate *node) final;
protected: protected:
......
...@@ -58,6 +58,11 @@ void TIntermCase::traverse(TIntermTraverser *it) ...@@ -58,6 +58,11 @@ void TIntermCase::traverse(TIntermTraverser *it)
it->traverseCase(this); it->traverseCase(this);
} }
void TIntermFunctionDefinition::traverse(TIntermTraverser *it)
{
it->traverseFunctionDefinition(this);
}
void TIntermBlock::traverse(TIntermTraverser *it) void TIntermBlock::traverse(TIntermTraverser *it)
{ {
it->traverseBlock(this); it->traverseBlock(this);
...@@ -424,6 +429,32 @@ void TLValueTrackingTraverser::traverseUnary(TIntermUnary *node) ...@@ -424,6 +429,32 @@ void TLValueTrackingTraverser::traverseUnary(TIntermUnary *node)
visitUnary(PostVisit, node); visitUnary(PostVisit, node);
} }
// Traverse a function definition node.
void TIntermTraverser::traverseFunctionDefinition(TIntermFunctionDefinition *node)
{
bool visit = true;
if (preVisit)
visit = visitFunctionDefinition(PreVisit, node);
if (visit)
{
incrementDepth(node);
mInGlobalScope = false;
node->getFunctionParameters()->traverse(this);
if (inVisit)
visit = visitFunctionDefinition(InVisit, node);
node->getBody()->traverse(this);
mInGlobalScope = true;
decrementDepth();
}
if (visit && postVisit)
visitFunctionDefinition(PostVisit, node);
}
// Traverse a block node. // Traverse a block node.
void TIntermTraverser::traverseBlock(TIntermBlock *node) void TIntermTraverser::traverseBlock(TIntermBlock *node)
{ {
...@@ -473,9 +504,6 @@ void TIntermTraverser::traverseAggregate(TIntermAggregate *node) ...@@ -473,9 +504,6 @@ void TIntermTraverser::traverseAggregate(TIntermAggregate *node)
{ {
incrementDepth(node); incrementDepth(node);
if (node->getOp() == EOpFunction)
mInGlobalScope = false;
for (auto *child : *sequence) for (auto *child : *sequence)
{ {
child->traverse(this); child->traverse(this);
...@@ -486,9 +514,6 @@ void TIntermTraverser::traverseAggregate(TIntermAggregate *node) ...@@ -486,9 +514,6 @@ void TIntermTraverser::traverseAggregate(TIntermAggregate *node)
} }
} }
if (node->getOp() == EOpFunction)
mInGlobalScope = true;
decrementDepth(); decrementDepth();
} }
...@@ -496,26 +521,24 @@ void TIntermTraverser::traverseAggregate(TIntermAggregate *node) ...@@ -496,26 +521,24 @@ void TIntermTraverser::traverseAggregate(TIntermAggregate *node)
visitAggregate(PostVisit, node); visitAggregate(PostVisit, node);
} }
void TLValueTrackingTraverser::traverseFunctionDefinition(TIntermFunctionDefinition *node)
{
TIntermAggregate *params = node->getFunctionParameters();
ASSERT(params != nullptr);
ASSERT(params->getOp() == EOpParameters);
addToFunctionMap(node->getFunctionSymbolInfo()->getNameObj(), params->getSequence());
TIntermTraverser::traverseFunctionDefinition(node);
}
void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node) void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
{ {
bool visit = true; bool visit = true;
TIntermSequence *sequence = node->getSequence(); TIntermSequence *sequence = node->getSequence();
switch (node->getOp()) if (node->getOp() == EOpPrototype)
{ {
case EOpFunction: addToFunctionMap(node->getFunctionSymbolInfo()->getNameObj(), sequence);
{
TIntermAggregate *params = sequence->front()->getAsAggregate();
ASSERT(params != nullptr);
ASSERT(params->getOp() == EOpParameters);
addToFunctionMap(node->getFunctionSymbolInfo()->getNameObj(), params->getSequence());
break;
}
case EOpPrototype:
addToFunctionMap(node->getFunctionSymbolInfo()->getNameObj(), sequence);
break;
default:
break;
} }
if (preVisit) if (preVisit)
...@@ -561,9 +584,6 @@ void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node) ...@@ -561,9 +584,6 @@ void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
} }
else else
{ {
if (node->getOp() == EOpFunction)
mInGlobalScope = false;
// Find the built-in function corresponding to this op so that we can determine the // Find the built-in function corresponding to this op so that we can determine the
// in/out qualifiers of its parameters. // in/out qualifiers of its parameters.
TFunction *builtInFunc = nullptr; TFunction *builtInFunc = nullptr;
...@@ -609,9 +629,6 @@ void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node) ...@@ -609,9 +629,6 @@ void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
} }
setInFunctionCallOutParameter(false); setInFunctionCallOutParameter(false);
if (node->getOp() == EOpFunction)
mInGlobalScope = true;
} }
decrementDepth(); decrementDepth();
......
...@@ -14,7 +14,6 @@ enum TOperator ...@@ -14,7 +14,6 @@ enum TOperator
{ {
EOpNull, // if in a node, should only mean a node is still being built EOpNull, // if in a node, should only mean a node is still being built
EOpFunctionCall, EOpFunctionCall,
EOpFunction, // For function definition
EOpParameters, // an aggregate listing the parameters to a function EOpParameters, // an aggregate listing the parameters to a function
EOpDeclaration, EOpDeclaration,
......
...@@ -22,9 +22,9 @@ TString arrayBrackets(const TType &type) ...@@ -22,9 +22,9 @@ TString arrayBrackets(const TType &type)
bool isSingleStatement(TIntermNode *node) bool isSingleStatement(TIntermNode *node)
{ {
if (const TIntermAggregate *aggregate = node->getAsAggregate()) if (node->getAsFunctionDefinition())
{ {
return (aggregate->getOp() != EOpFunction); return false;
} }
else if (node->getAsBlock()) else if (node->getAsBlock())
{ {
...@@ -786,6 +786,35 @@ bool TOutputGLSLBase::visitBlock(Visit visit, TIntermBlock *node) ...@@ -786,6 +786,35 @@ bool TOutputGLSLBase::visitBlock(Visit visit, TIntermBlock *node)
return false; return false;
} }
bool TOutputGLSLBase::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{
TInfoSinkBase &out = objSink();
ASSERT(visit == PreVisit);
{
const TType &type = node->getType();
writeVariableType(type);
if (type.isArray())
out << arrayBrackets(type);
}
out << " " << hashFunctionNameIfNeeded(node->getFunctionSymbolInfo()->getNameObj());
incrementDepth(node);
// Traverse function parameters.
TIntermAggregate *params = node->getFunctionParameters()->getAsAggregate();
ASSERT(params->getOp() == EOpParameters);
params->traverse(this);
// Traverse function body.
visitCodeBlock(node->getBody());
decrementDepth();
// Fully processed; no need to visit children.
return false;
}
bool TOutputGLSLBase::visitAggregate(Visit visit, TIntermAggregate *node) bool TOutputGLSLBase::visitAggregate(Visit visit, TIntermAggregate *node)
{ {
bool visitChildren = true; bool visitChildren = true;
...@@ -811,39 +840,6 @@ bool TOutputGLSLBase::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -811,39 +840,6 @@ bool TOutputGLSLBase::visitAggregate(Visit visit, TIntermAggregate *node)
visitChildren = false; visitChildren = false;
break; break;
case EOpFunction: {
// Function definition.
ASSERT(visit == PreVisit);
{
const TType &type = node->getType();
writeVariableType(type);
if (type.isArray())
out << arrayBrackets(type);
}
out << " " << hashFunctionNameIfNeeded(node->getFunctionSymbolInfo()->getNameObj());
incrementDepth(node);
// Function definition node contains two child nodes representing the function parameters
// and the function body.
const TIntermSequence &sequence = *(node->getSequence());
ASSERT(sequence.size() == 2);
// Traverse function parameters.
TIntermAggregate *params = sequence[0]->getAsAggregate();
ASSERT(params != NULL);
ASSERT(params->getOp() == EOpParameters);
params->traverse(this);
// Traverse function body.
TIntermBlock *body = sequence[1]->getAsBlock();
visitCodeBlock(body);
decrementDepth();
// Fully processed; no need to visit children.
visitChildren = false;
break;
}
case EOpFunctionCall: case EOpFunctionCall:
// Function call. // Function call.
if (visit == PreVisit) if (visit == PreVisit)
......
...@@ -49,6 +49,7 @@ class TOutputGLSLBase : public TIntermTraverser ...@@ -49,6 +49,7 @@ class TOutputGLSLBase : public TIntermTraverser
bool visitIfElse(Visit visit, TIntermIfElse *node) override; bool visitIfElse(Visit visit, TIntermIfElse *node) override;
bool visitSwitch(Visit visit, TIntermSwitch *node) override; bool visitSwitch(Visit visit, TIntermSwitch *node) override;
bool visitCase(Visit visit, TIntermCase *node) override; bool visitCase(Visit visit, TIntermCase *node) override;
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *node) override; bool visitAggregate(Visit visit, TIntermAggregate *node) override;
bool visitBlock(Visit visit, TIntermBlock *node) override; bool visitBlock(Visit visit, TIntermBlock *node) override;
bool visitLoop(Visit visit, TIntermLoop *node) override; bool visitLoop(Visit visit, TIntermLoop *node) override;
......
...@@ -1449,6 +1449,70 @@ bool OutputHLSL::visitBlock(Visit visit, TIntermBlock *node) ...@@ -1449,6 +1449,70 @@ bool OutputHLSL::visitBlock(Visit visit, TIntermBlock *node)
return false; return false;
} }
bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{
TInfoSinkBase &out = getInfoSink();
ASSERT(mCurrentFunctionMetadata == nullptr);
size_t index = mCallDag.findIndex(node->getFunctionSymbolInfo());
ASSERT(index != CallDAG::InvalidIndex);
mCurrentFunctionMetadata = &mASTMetadataList[index];
out << TypeString(node->getType()) << " ";
TIntermSequence *parameters = node->getFunctionParameters()->getSequence();
if (node->getFunctionSymbolInfo()->isMain())
{
out << "gl_main(";
}
else
{
out << DecorateFunctionIfNeeded(node->getFunctionSymbolInfo()->getNameObj())
<< DisambiguateFunctionName(parameters) << (mOutputLod0Function ? "Lod0(" : "(");
}
for (unsigned int i = 0; i < parameters->size(); i++)
{
TIntermSymbol *symbol = (*parameters)[i]->getAsSymbolNode();
if (symbol)
{
ensureStructDefined(symbol->getType());
out << argumentString(symbol);
if (i < parameters->size() - 1)
{
out << ", ";
}
}
else
UNREACHABLE();
}
out << ")\n";
mInsideFunction = true;
// The function body node will output braces.
node->getBody()->traverse(this);
mInsideFunction = false;
mCurrentFunctionMetadata = nullptr;
bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
{
ASSERT(!node->getFunctionSymbolInfo()->isMain());
mOutputLod0Function = true;
node->traverse(this);
mOutputLod0Function = false;
}
return false;
}
bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node) bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
{ {
TInfoSinkBase &out = getInfoSink(); TInfoSinkBase &out = getInfoSink();
...@@ -1581,70 +1645,6 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -1581,70 +1645,6 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
case EOpComma: case EOpComma:
outputTriplet(out, visit, "(", ", ", ")"); outputTriplet(out, visit, "(", ", ", ")");
break; break;
case EOpFunction:
{
ASSERT(mCurrentFunctionMetadata == nullptr);
size_t index = mCallDag.findIndex(node->getFunctionSymbolInfo());
ASSERT(index != CallDAG::InvalidIndex);
mCurrentFunctionMetadata = &mASTMetadataList[index];
out << TypeString(node->getType()) << " ";
TIntermSequence *sequence = node->getSequence();
TIntermSequence *arguments = (*sequence)[0]->getAsAggregate()->getSequence();
if (node->getFunctionSymbolInfo()->isMain())
{
out << "gl_main(";
}
else
{
out << DecorateFunctionIfNeeded(node->getFunctionSymbolInfo()->getNameObj())
<< DisambiguateFunctionName(arguments) << (mOutputLod0Function ? "Lod0(" : "(");
}
for (unsigned int i = 0; i < arguments->size(); i++)
{
TIntermSymbol *symbol = (*arguments)[i]->getAsSymbolNode();
if (symbol)
{
ensureStructDefined(symbol->getType());
out << argumentString(symbol);
if (i < arguments->size() - 1)
{
out << ", ";
}
}
else UNREACHABLE();
}
out << ")\n";
mInsideFunction = true;
ASSERT(sequence->size() == 2);
TIntermNode *body = (*sequence)[1];
// The function body node will output braces.
ASSERT(body->getAsBlock() != nullptr);
body->traverse(this);
mInsideFunction = false;
mCurrentFunctionMetadata = nullptr;
bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
{
ASSERT(!node->getFunctionSymbolInfo()->isMain());
mOutputLod0Function = true;
node->traverse(this);
mOutputLod0Function = false;
}
return false;
}
case EOpFunctionCall: case EOpFunctionCall:
{ {
TIntermSequence *arguments = node->getSequence(); TIntermSequence *arguments = node->getSequence();
......
...@@ -66,6 +66,7 @@ class OutputHLSL : public TIntermTraverser ...@@ -66,6 +66,7 @@ class OutputHLSL : public TIntermTraverser
bool visitIfElse(Visit visit, TIntermIfElse *); bool visitIfElse(Visit visit, TIntermIfElse *);
bool visitSwitch(Visit visit, TIntermSwitch *); bool visitSwitch(Visit visit, TIntermSwitch *);
bool visitCase(Visit visit, TIntermCase *); bool visitCase(Visit visit, TIntermCase *);
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
bool visitAggregate(Visit visit, TIntermAggregate*); bool visitAggregate(Visit visit, TIntermAggregate*);
bool visitBlock(Visit visit, TIntermBlock *node); bool visitBlock(Visit visit, TIntermBlock *node);
bool visitLoop(Visit visit, TIntermLoop*); bool visitLoop(Visit visit, TIntermLoop*);
......
...@@ -2059,10 +2059,11 @@ TIntermAggregate *TParseContext::addFunctionPrototypeDeclaration(const TFunction ...@@ -2059,10 +2059,11 @@ TIntermAggregate *TParseContext::addFunctionPrototypeDeclaration(const TFunction
return prototype; return prototype;
} }
TIntermAggregate *TParseContext::addFunctionDefinition(const TFunction &function, TIntermFunctionDefinition *TParseContext::addFunctionDefinition(
TIntermAggregate *functionPrototype, const TFunction &function,
TIntermBlock *functionBody, TIntermAggregate *functionParameters,
const TSourceLoc &location) TIntermBlock *functionBody,
const TSourceLoc &location)
{ {
// Check that non-void functions have at least one return statement. // Check that non-void functions have at least one return statement.
if (mCurrentFunctionType->getBasicType() != EbtVoid && !mFunctionReturnsValue) if (mCurrentFunctionType->getBasicType() != EbtVoid && !mFunctionReturnsValue)
...@@ -2070,21 +2071,16 @@ TIntermAggregate *TParseContext::addFunctionDefinition(const TFunction &function ...@@ -2070,21 +2071,16 @@ TIntermAggregate *TParseContext::addFunctionDefinition(const TFunction &function
error(location, "function does not return a value:", "", function.getName().c_str()); error(location, "function does not return a value:", "", function.getName().c_str());
} }
TIntermAggregate *functionNode = new TIntermAggregate(EOpFunction);
functionNode->setLine(location);
ASSERT(functionPrototype != nullptr);
functionNode->getSequence()->push_back(functionPrototype);
if (functionBody == nullptr) if (functionBody == nullptr)
{ {
functionBody = new TIntermBlock(); functionBody = new TIntermBlock();
functionBody->setLine(location); functionBody->setLine(location);
} }
functionNode->getSequence()->push_back(functionBody); TIntermFunctionDefinition *functionNode =
new TIntermFunctionDefinition(function.getReturnType(), functionParameters, functionBody);
functionNode->setLine(location);
functionNode->getFunctionSymbolInfo()->setFromFunction(function); functionNode->getFunctionSymbolInfo()->setFromFunction(function);
functionNode->setType(function.getReturnType());
symbolTable.pop(); symbolTable.pop();
return functionNode; return functionNode;
......
...@@ -266,10 +266,10 @@ class TParseContext : angle::NonCopyable ...@@ -266,10 +266,10 @@ class TParseContext : angle::NonCopyable
void parseGlobalLayoutQualifier(const TTypeQualifierBuilder &typeQualifierBuilder); void parseGlobalLayoutQualifier(const TTypeQualifierBuilder &typeQualifierBuilder);
TIntermAggregate *addFunctionPrototypeDeclaration(const TFunction &parsedFunction, TIntermAggregate *addFunctionPrototypeDeclaration(const TFunction &parsedFunction,
const TSourceLoc &location); const TSourceLoc &location);
TIntermAggregate *addFunctionDefinition(const TFunction &function, TIntermFunctionDefinition *addFunctionDefinition(const TFunction &function,
TIntermAggregate *functionPrototype, TIntermAggregate *functionParameters,
TIntermBlock *functionBody, TIntermBlock *functionBody,
const TSourceLoc &location); const TSourceLoc &location);
void parseFunctionDefinitionHeader(const TSourceLoc &location, void parseFunctionDefinitionHeader(const TSourceLoc &location,
TFunction **function, TFunction **function,
TIntermAggregate **aggregateOut); TIntermAggregate **aggregateOut);
......
...@@ -172,7 +172,7 @@ TType GetFieldType(const TType &indexedType) ...@@ -172,7 +172,7 @@ TType GetFieldType(const TType &indexedType)
// base[1] = value; // base[1] = value;
// } // }
// Note that else is not used in above functions to avoid the RewriteElseBlocks transformation. // Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write) TIntermFunctionDefinition *GetIndexFunctionDefinition(TType type, bool write)
{ {
ASSERT(!type.isArray()); ASSERT(!type.isArray());
// Conservatively use highp here, even if the indexed type is not highp. That way the code can't // Conservatively use highp here, even if the indexed type is not highp. That way the code can't
...@@ -180,8 +180,6 @@ TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write) ...@@ -180,8 +180,6 @@ TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write)
// highp values are being indexed in the shader. For HLSL precision doesn't matter, but in // highp values are being indexed in the shader. For HLSL precision doesn't matter, but in
// principle this code could be used with multiple backends. // principle this code could be used with multiple backends.
type.setPrecision(EbpHigh); type.setPrecision(EbpHigh);
TIntermAggregate *indexingFunction = new TIntermAggregate(EOpFunction);
indexingFunction->getFunctionSymbolInfo()->setNameObj(GetIndexFunctionName(type, write));
TType fieldType = GetFieldType(type); TType fieldType = GetFieldType(type);
int numCases = 0; int numCases = 0;
...@@ -193,14 +191,6 @@ TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write) ...@@ -193,14 +191,6 @@ TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write)
{ {
numCases = type.getNominalSize(); numCases = type.getNominalSize();
} }
if (write)
{
indexingFunction->setType(TType(EbtVoid));
}
else
{
indexingFunction->setType(fieldType);
}
TIntermAggregate *paramsNode = new TIntermAggregate(EOpParameters); TIntermAggregate *paramsNode = new TIntermAggregate(EOpParameters);
TQualifier baseQualifier = EvqInOut; TQualifier baseQualifier = EvqInOut;
...@@ -215,7 +205,6 @@ TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write) ...@@ -215,7 +205,6 @@ TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write)
TIntermSymbol *valueParam = CreateValueSymbol(fieldType); TIntermSymbol *valueParam = CreateValueSymbol(fieldType);
paramsNode->getSequence()->push_back(valueParam); paramsNode->getSequence()->push_back(valueParam);
} }
indexingFunction->getSequence()->push_back(paramsNode);
TIntermBlock *statementList = new TIntermBlock(); TIntermBlock *statementList = new TIntermBlock();
for (int i = 0; i < numCases; ++i) for (int i = 0; i < numCases; ++i)
...@@ -284,8 +273,16 @@ TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write) ...@@ -284,8 +273,16 @@ TIntermAggregate *GetIndexFunctionDefinition(TType type, bool write)
bodyNode->getSequence()->push_back(ifNode); bodyNode->getSequence()->push_back(ifNode);
bodyNode->getSequence()->push_back(useLastBlock); bodyNode->getSequence()->push_back(useLastBlock);
indexingFunction->getSequence()->push_back(bodyNode); TIntermFunctionDefinition *indexingFunction = nullptr;
if (write)
{
indexingFunction = new TIntermFunctionDefinition(TType(EbtVoid), paramsNode, bodyNode);
}
else
{
indexingFunction = new TIntermFunctionDefinition(fieldType, paramsNode, bodyNode);
}
indexingFunction->getFunctionSymbolInfo()->setNameObj(GetIndexFunctionName(type, write));
return indexingFunction; return indexingFunction;
} }
......
...@@ -25,7 +25,7 @@ class ElseBlockRewriter : public TIntermTraverser ...@@ -25,7 +25,7 @@ class ElseBlockRewriter : public TIntermTraverser
ElseBlockRewriter(); ElseBlockRewriter();
protected: protected:
bool visitAggregate(Visit visit, TIntermAggregate *aggregate) override; bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *aggregate) override;
bool visitBlock(Visit visit, TIntermBlock *block) override; bool visitBlock(Visit visit, TIntermBlock *block) override;
private: private:
...@@ -39,13 +39,10 @@ ElseBlockRewriter::ElseBlockRewriter() ...@@ -39,13 +39,10 @@ ElseBlockRewriter::ElseBlockRewriter()
mFunctionType(NULL) mFunctionType(NULL)
{} {}
bool ElseBlockRewriter::visitAggregate(Visit visit, TIntermAggregate *node) bool ElseBlockRewriter::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{ {
if (node->getOp() == EOpFunction) // Store the current function context (see comment below)
{ mFunctionType = ((visit == PreVisit) ? &node->getType() : nullptr);
// Store the current function context (see comment below)
mFunctionType = ((visit == PreVisit) ? &node->getType() : nullptr);
}
return true; return true;
} }
......
...@@ -94,7 +94,7 @@ bool SimplifyLoopConditionsTraverser::visitAggregate(Visit visit, TIntermAggrega ...@@ -94,7 +94,7 @@ bool SimplifyLoopConditionsTraverser::visitAggregate(Visit visit, TIntermAggrega
// If we're outside a loop condition, we only need to traverse nodes that may contain loops. // If we're outside a loop condition, we only need to traverse nodes that may contain loops.
if (!mInsideLoopConditionOrExpression) if (!mInsideLoopConditionOrExpression)
return (node->getOp() == EOpFunction); return false;
mFoundLoopToChange = mConditionsToSimplify.match(node, getParentNode()); mFoundLoopToChange = mConditionsToSimplify.match(node, getParentNode());
return !mFoundLoopToChange; return !mFoundLoopToChange;
......
...@@ -28,7 +28,8 @@ class UseUniformBlockMembers : public TIntermTraverser ...@@ -28,7 +28,8 @@ class UseUniformBlockMembers : public TIntermTraverser
} }
protected: protected:
bool visitAggregate(Visit visit, TIntermAggregate *node) override; bool visitAggregate(Visit visit, TIntermAggregate *node) override { return !mCodeInserted; }
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
private: private:
void insertUseCode(TIntermSequence *sequence); void insertUseCode(TIntermSequence *sequence);
...@@ -38,31 +39,18 @@ class UseUniformBlockMembers : public TIntermTraverser ...@@ -38,31 +39,18 @@ class UseUniformBlockMembers : public TIntermTraverser
bool mCodeInserted; bool mCodeInserted;
}; };
bool UseUniformBlockMembers::visitAggregate(Visit visit, TIntermAggregate *node) bool UseUniformBlockMembers::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{ {
bool visitChildren = !mCodeInserted; ASSERT(visit == PreVisit);
switch (node->getOp()) if (node->getFunctionSymbolInfo()->isMain())
{ {
case EOpFunction: TIntermBlock *body = node->getBody();
{ ASSERT(body);
ASSERT(visit == PreVisit); insertUseCode(body->getSequence());
if (node->getFunctionSymbolInfo()->isMain()) mCodeInserted = true;
{ return false;
TIntermSequence *sequence = node->getSequence();
ASSERT(sequence->size() == 2);
TIntermBlock *body = (*sequence)[1]->getAsBlock();
ASSERT(body);
insertUseCode(body->getSequence());
mCodeInserted = true;
visitChildren = false;
}
break;
}
default:
visitChildren = false;
break;
} }
return visitChildren; return !mCodeInserted;
} }
void UseUniformBlockMembers::AddFieldUseStatements(const ShaderVariable &var, void UseUniformBlockMembers::AddFieldUseStatements(const ShaderVariable &var,
......
...@@ -10,11 +10,11 @@ ...@@ -10,11 +10,11 @@
namespace namespace
{ {
void OutputFunction(TInfoSinkBase &out, const char *str, TIntermAggregate *node) void OutputFunction(TInfoSinkBase &out, const char *str, TFunctionSymbolInfo *info)
{ {
const char *internal = const char *internal = info->getNameObj().isInternal() ? " (internal function)" : "";
node->getFunctionSymbolInfo()->getNameObj().isInternal() ? " (internal function)" : ""; out << str << internal << ": " << info->getNameObj().getString() << " (symbol id "
out << str << internal << ": " << node->getFunctionSymbolInfo()->getNameObj().getString(); << info->getId() << ")";
} }
// //
...@@ -48,6 +48,7 @@ class TOutputTraverser : public TIntermTraverser ...@@ -48,6 +48,7 @@ class TOutputTraverser : public TIntermTraverser
bool visitUnary(Visit visit, TIntermUnary *) override; bool visitUnary(Visit visit, TIntermUnary *) override;
bool visitTernary(Visit visit, TIntermTernary *node) override; bool visitTernary(Visit visit, TIntermTernary *node) override;
bool visitIfElse(Visit visit, TIntermIfElse *node) override; bool visitIfElse(Visit visit, TIntermIfElse *node) override;
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *) override; bool visitAggregate(Visit visit, TIntermAggregate *) override;
bool visitBlock(Visit visit, TIntermBlock *) override; bool visitBlock(Visit visit, TIntermBlock *) override;
bool visitLoop(Visit visit, TIntermLoop *) override; bool visitLoop(Visit visit, TIntermLoop *) override;
...@@ -372,6 +373,15 @@ bool TOutputTraverser::visitUnary(Visit visit, TIntermUnary *node) ...@@ -372,6 +373,15 @@ bool TOutputTraverser::visitUnary(Visit visit, TIntermUnary *node)
return true; return true;
} }
bool TOutputTraverser::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{
TInfoSinkBase &out = sink;
OutputTreeText(out, node, mDepth);
OutputFunction(out, "Function Definition", node->getFunctionSymbolInfo());
out << "\n";
return true;
}
bool TOutputTraverser::visitAggregate(Visit visit, TIntermAggregate *node) bool TOutputTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
{ {
TInfoSinkBase &out = sink; TInfoSinkBase &out = sink;
...@@ -389,10 +399,13 @@ bool TOutputTraverser::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -389,10 +399,13 @@ bool TOutputTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
switch (node->getOp()) switch (node->getOp())
{ {
case EOpComma: out << "Comma\n"; return true; case EOpComma: out << "Comma\n"; return true;
case EOpFunction: OutputFunction(out, "Function Definition", node); break; case EOpFunctionCall:
case EOpFunctionCall: OutputFunction(out, "Function Call", node); break; OutputFunction(out, "Function Call", node->getFunctionSymbolInfo());
break;
case EOpParameters: out << "Function Parameters: "; break; case EOpParameters: out << "Function Parameters: "; break;
case EOpPrototype: OutputFunction(out, "Function Prototype", node); break; case EOpPrototype:
OutputFunction(out, "Function Prototype", node->getFunctionSymbolInfo());
break;
case EOpConstructFloat: out << "Construct float"; break; case EOpConstructFloat: out << "Construct float"; break;
case EOpConstructVec2: out << "Construct vec2"; break; case EOpConstructVec2: out << "Construct vec2"; break;
......
...@@ -39,7 +39,7 @@ class TypeTrackingTest : public testing::Test ...@@ -39,7 +39,7 @@ class TypeTrackingTest : public testing::Test
const char *shaderStrings[] = { shaderString.c_str() }; const char *shaderStrings[] = { shaderString.c_str() };
bool compilationSuccess = mTranslator->compile(shaderStrings, 1, SH_INTERMEDIATE_TREE); bool compilationSuccess = mTranslator->compile(shaderStrings, 1, SH_INTERMEDIATE_TREE);
TInfoSink &infoSink = mTranslator->getInfoSink(); TInfoSink &infoSink = mTranslator->getInfoSink();
mInfoLog = infoSink.info.c_str(); mInfoLog = RemoveSymbolIdsFromInfoLog(infoSink.info.c_str());
if (!compilationSuccess) if (!compilationSuccess)
FAIL() << "Shader compilation failed " << mInfoLog; FAIL() << "Shader compilation failed " << mInfoLog;
} }
...@@ -55,6 +55,23 @@ class TypeTrackingTest : public testing::Test ...@@ -55,6 +55,23 @@ class TypeTrackingTest : public testing::Test
} }
private: private:
// Remove symbol ids from info log - the tests don't care about them.
static std::string RemoveSymbolIdsFromInfoLog(const char *infoLog)
{
std::string filteredLog(infoLog);
size_t idPrefixPos = 0u;
do
{
idPrefixPos = filteredLog.find(" (symbol id");
if (idPrefixPos != std::string::npos)
{
size_t idSuffixPos = filteredLog.find(")", idPrefixPos);
filteredLog.erase(idPrefixPos, idSuffixPos - idPrefixPos + 1u);
}
} while (idPrefixPos != std::string::npos);
return filteredLog;
}
TranslatorESSL *mTranslator; TranslatorESSL *mTranslator;
std::string mInfoLog; std::string mInfoLog;
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment