Commit 16c745a3 by Olli Etuaho Committed by Commit Bot

Split TIntermFunctionPrototype from TIntermAggregate

Function prototypes now have their own class TIntermFunctionPrototype. It's only used for prototypes, not function parameter lists. TIntermAggregate is still used for parameter lists and function calls. BUGS=angleproject:1490 TEST=angle_unittests Change-Id: I6e246ad00a29c2335bd2ab7f61cf73fe463b74bb Reviewed-on: https://chromium-review.googlesource.com/427944Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org> Reviewed-by: 's avatarCorentin Wallez <cwallez@chromium.org> Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
parent bddc46b4
...@@ -17,7 +17,7 @@ namespace sh ...@@ -17,7 +17,7 @@ namespace sh
namespace namespace
{ {
void CopyAggregateChildren(TIntermAggregate *from, TIntermAggregate *to) void CopyAggregateChildren(TIntermAggregateBase *from, TIntermAggregateBase *to)
{ {
const TIntermSequence *fromSequence = from->getSequence(); const TIntermSequence *fromSequence = from->getSequence();
for (size_t ii = 0; ii < fromSequence->size(); ++ii) for (size_t ii = 0; ii < fromSequence->size(); ++ii)
...@@ -66,6 +66,7 @@ class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser ...@@ -66,6 +66,7 @@ class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser
private: private:
ArrayReturnValueToOutParameterTraverser(); ArrayReturnValueToOutParameterTraverser();
bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override;
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override; bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *node) override; bool visitAggregate(Visit visit, TIntermAggregate *node) override;
bool visitBranch(Visit visit, TIntermBranch *node) override; bool visitBranch(Visit visit, TIntermBranch *node) override;
...@@ -121,54 +122,49 @@ bool ArrayReturnValueToOutParameterTraverser::visitFunctionDefinition( ...@@ -121,54 +122,49 @@ bool ArrayReturnValueToOutParameterTraverser::visitFunctionDefinition(
return true; return true;
} }
bool ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(Visit visit,
TIntermFunctionPrototype *node)
{
if (visit == PreVisit && node->isArray())
{
// Replace the whole prototype node with another node that has the out parameter
// added.
TIntermFunctionPrototype *replacement = new TIntermFunctionPrototype(TType(EbtVoid));
CopyAggregateChildren(node, replacement);
replacement->getSequence()->push_back(CreateReturnValueOutSymbol(node->getType()));
*replacement->getFunctionSymbolInfo() = *node->getFunctionSymbolInfo();
replacement->setLine(node->getLine());
queueReplacement(node, replacement, OriginalNode::IS_DROPPED);
}
return false;
}
bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TIntermAggregate *node) bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
{ {
if (visit == PreVisit) if (visit == PreVisit && node->isArray() && node->getOp() == EOpFunctionCall)
{ {
if (node->isArray()) // Handle call sites where the returned array is not assigned.
// Examples where f() is a function returning an array:
// 1. f();
// 2. another_array == f();
// 3. another_function(f());
// 4. return f();
// Cases 2 to 4 are already converted to simpler cases by
// SeparateExpressionsReturningArrays, so we only need to worry about the case where a
// function call returning an array forms an expression by itself.
TIntermBlock *parentBlock = getParentNode()->getAsBlock();
if (parentBlock)
{ {
if (node->getOp() == EOpPrototype) nextTemporaryIndex();
{ TIntermSequence replacements;
// Replace the whole prototype node with another node that has the out parameter replacements.push_back(createTempDeclaration(node->getType()));
// added. TIntermSymbol *returnSymbol = createTempSymbol(node->getType());
TIntermAggregate *replacement = new TIntermAggregate; replacements.push_back(CreateReplacementCall(node, returnSymbol));
replacement->setOp(EOpPrototype); mMultiReplacements.push_back(
CopyAggregateChildren(node, replacement); NodeReplaceWithMultipleEntry(parentBlock, node, replacements));
replacement->getSequence()->push_back(CreateReturnValueOutSymbol(node->getType()));
replacement->setUserDefined();
*replacement->getFunctionSymbolInfo() = *node->getFunctionSymbolInfo();
replacement->setLine(node->getLine());
replacement->setType(TType(EbtVoid));
queueReplacement(node, replacement, OriginalNode::IS_DROPPED);
}
else if (node->getOp() == EOpFunctionCall)
{
// Handle call sites where the returned array is not assigned.
// Examples where f() is a function returning an array:
// 1. f();
// 2. another_array == f();
// 3. another_function(f());
// 4. return f();
// Cases 2 to 4 are already converted to simpler cases by
// SeparateExpressionsReturningArrays, so we
// only need to worry about the case where a function call returning an array forms
// an expression by
// itself.
TIntermBlock *parentBlock = getParentNode()->getAsBlock();
if (parentBlock)
{
nextTemporaryIndex();
TIntermSequence replacements;
replacements.push_back(createTempDeclaration(node->getType()));
TIntermSymbol *returnSymbol = createTempSymbol(node->getType());
replacements.push_back(CreateReplacementCall(node, returnSymbol));
mMultiReplacements.push_back(
NodeReplaceWithMultipleEntry(parentBlock, node, replacements));
}
return false;
}
} }
return false;
} }
return true; return true;
} }
......
...@@ -122,19 +122,22 @@ class CallDAG::CallDAGCreator : public TIntermTraverser ...@@ -122,19 +122,22 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
return true; return true;
} }
bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override
{
ASSERT(visit == PreVisit);
// Function declaration, create an empty record.
auto &record = mFunctions[node->getFunctionSymbolInfo()->getName()];
record.name = node->getFunctionSymbolInfo()->getName();
// No need to traverse the parameters.
return false;
}
// Aggregates the AST node for each function as well as the name of the functions called by it // Aggregates the AST node for each function as well as the name of the functions called by it
bool visitAggregate(Visit visit, TIntermAggregate *node) override bool visitAggregate(Visit visit, TIntermAggregate *node) override
{ {
switch (node->getOp()) switch (node->getOp())
{ {
case EOpPrototype:
if (visit == PreVisit)
{
// Function declaration, create an empty record.
auto &record = mFunctions[node->getFunctionSymbolInfo()->getName()];
record.name = node->getFunctionSymbolInfo()->getName();
}
break;
case EOpFunctionCall: case EOpFunctionCall:
{ {
// Function call, add the callees // Function call, add the callees
......
...@@ -791,21 +791,18 @@ class TCompiler::UnusedPredicate ...@@ -791,21 +791,18 @@ class TCompiler::UnusedPredicate
bool operator()(TIntermNode *node) bool operator()(TIntermNode *node)
{ {
const TIntermAggregate *asAggregate = node->getAsAggregate(); const TIntermFunctionPrototype *asFunctionPrototype = node->getAsFunctionPrototypeNode();
const TIntermFunctionDefinition *asFunction = node->getAsFunctionDefinition(); const TIntermFunctionDefinition *asFunctionDefinition = node->getAsFunctionDefinition();
const TFunctionSymbolInfo *functionInfo = nullptr; const TFunctionSymbolInfo *functionInfo = nullptr;
if (asFunction) if (asFunctionDefinition)
{ {
functionInfo = asFunction->getFunctionSymbolInfo(); functionInfo = asFunctionDefinition->getFunctionSymbolInfo();
} }
else if (asAggregate) else if (asFunctionPrototype)
{ {
if (asAggregate->getOp() == EOpPrototype) functionInfo = asFunctionPrototype->getFunctionSymbolInfo();
{
functionInfo = asAggregate->getFunctionSymbolInfo();
}
} }
if (functionInfo == nullptr) if (functionInfo == nullptr)
{ {
...@@ -816,7 +813,7 @@ class TCompiler::UnusedPredicate ...@@ -816,7 +813,7 @@ class TCompiler::UnusedPredicate
if (callDagIndex == CallDAG::InvalidIndex) if (callDagIndex == CallDAG::InvalidIndex)
{ {
// This happens only for unimplemented prototypes which are thus unused // This happens only for unimplemented prototypes which are thus unused
ASSERT(asAggregate && asAggregate->getOp() == EOpPrototype); ASSERT(asFunctionPrototype);
return true; return true;
} }
......
...@@ -30,13 +30,12 @@ void SetInternalFunctionName(TFunctionSymbolInfo *functionInfo, const char *name ...@@ -30,13 +30,12 @@ void SetInternalFunctionName(TFunctionSymbolInfo *functionInfo, const char *name
functionInfo->setNameObj(nameObj); functionInfo->setNameObj(nameObj);
} }
TIntermAggregate *CreateFunctionPrototypeNode(const char *name, const int functionId) TIntermFunctionPrototype *CreateFunctionPrototypeNode(const char *name, const int functionId)
{ {
TIntermAggregate *functionNode = new TIntermAggregate(EOpPrototype); TType returnType(EbtVoid);
TIntermFunctionPrototype *functionNode = new TIntermFunctionPrototype(returnType);
SetInternalFunctionName(functionNode->getFunctionSymbolInfo(), name); SetInternalFunctionName(functionNode->getFunctionSymbolInfo(), name);
TType returnType(EbtVoid);
functionNode->setType(returnType);
functionNode->getFunctionSymbolInfo()->setId(functionId); functionNode->getFunctionSymbolInfo()->setId(functionId);
return functionNode; return functionNode;
} }
...@@ -146,7 +145,7 @@ void DeferGlobalInitializersTraverser::insertInitFunction(TIntermBlock *root) ...@@ -146,7 +145,7 @@ void DeferGlobalInitializersTraverser::insertInitFunction(TIntermBlock *root)
const char *functionName = "initializeDeferredGlobals"; const char *functionName = "initializeDeferredGlobals";
// Add function prototype to the beginning of the shader // Add function prototype to the beginning of the shader
TIntermAggregate *functionPrototypeNode = TIntermFunctionPrototype *functionPrototypeNode =
CreateFunctionPrototypeNode(functionName, initFunctionId); CreateFunctionPrototypeNode(functionName, initFunctionId);
root->getSequence()->insert(root->getSequence()->begin(), functionPrototypeNode); root->getSequence()->insert(root->getSequence()->begin(), functionPrototypeNode);
......
...@@ -628,6 +628,11 @@ bool EmulatePrecision::visitInvariantDeclaration(Visit visit, TIntermInvariantDe ...@@ -628,6 +628,11 @@ bool EmulatePrecision::visitInvariantDeclaration(Visit visit, TIntermInvariantDe
return false; return false;
} }
bool EmulatePrecision::visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node)
{
return false;
}
bool EmulatePrecision::visitAggregate(Visit visit, TIntermAggregate *node) bool EmulatePrecision::visitAggregate(Visit visit, TIntermAggregate *node)
{ {
bool visitChildren = true; bool visitChildren = true;
...@@ -635,9 +640,6 @@ bool EmulatePrecision::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -635,9 +640,6 @@ bool EmulatePrecision::visitAggregate(Visit visit, TIntermAggregate *node)
{ {
case EOpConstructStruct: case EOpConstructStruct:
break; break;
case EOpPrototype:
visitChildren = false;
break;
case EOpParameters: case EOpParameters:
visitChildren = false; visitChildren = false;
break; break;
......
...@@ -32,6 +32,7 @@ class EmulatePrecision : public TLValueTrackingTraverser ...@@ -32,6 +32,7 @@ class EmulatePrecision : public TLValueTrackingTraverser
bool visitAggregate(Visit visit, TIntermAggregate *node) override; bool visitAggregate(Visit visit, TIntermAggregate *node) override;
bool visitInvariantDeclaration(Visit visit, TIntermInvariantDeclaration *node) override; bool visitInvariantDeclaration(Visit visit, TIntermInvariantDeclaration *node) override;
bool visitDeclaration(Visit visit, TIntermDeclaration *node) override; bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override;
void writeEmulationHelpers(TInfoSinkBase &sink, void writeEmulationHelpers(TInfoSinkBase &sink,
const int shaderVersion, const int shaderVersion,
......
...@@ -226,6 +226,11 @@ bool TIntermBlock::replaceChildNode(TIntermNode *original, TIntermNode *replacem ...@@ -226,6 +226,11 @@ bool TIntermBlock::replaceChildNode(TIntermNode *original, TIntermNode *replacem
return replaceChildNodeInternal(original, replacement); return replaceChildNodeInternal(original, replacement);
} }
bool TIntermFunctionPrototype::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
return replaceChildNodeInternal(original, replacement);
}
bool TIntermDeclaration::replaceChildNode(TIntermNode *original, TIntermNode *replacement) bool TIntermDeclaration::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{ {
return replaceChildNodeInternal(original, replacement); return replaceChildNodeInternal(original, replacement);
...@@ -338,6 +343,12 @@ void TIntermBlock::appendStatement(TIntermNode *statement) ...@@ -338,6 +343,12 @@ void TIntermBlock::appendStatement(TIntermNode *statement)
} }
} }
void TIntermFunctionPrototype::appendParameter(TIntermSymbol *parameter)
{
ASSERT(parameter != nullptr);
mParameters.push_back(parameter);
}
void TIntermDeclaration::appendDeclarator(TIntermTyped *declarator) void TIntermDeclaration::appendDeclarator(TIntermTyped *declarator)
{ {
ASSERT(declarator != nullptr); ASSERT(declarator != nullptr);
......
...@@ -37,6 +37,7 @@ class TIntermAggregate; ...@@ -37,6 +37,7 @@ class TIntermAggregate;
class TIntermBlock; class TIntermBlock;
class TIntermInvariantDeclaration; class TIntermInvariantDeclaration;
class TIntermDeclaration; class TIntermDeclaration;
class TIntermFunctionPrototype;
class TIntermFunctionDefinition; class TIntermFunctionDefinition;
class TIntermSwizzle; class TIntermSwizzle;
class TIntermBinary; class TIntermBinary;
...@@ -103,6 +104,7 @@ class TIntermNode : angle::NonCopyable ...@@ -103,6 +104,7 @@ class TIntermNode : angle::NonCopyable
virtual TIntermFunctionDefinition *getAsFunctionDefinition() { return nullptr; } virtual TIntermFunctionDefinition *getAsFunctionDefinition() { return nullptr; }
virtual TIntermAggregate *getAsAggregate() { return 0; } virtual TIntermAggregate *getAsAggregate() { return 0; }
virtual TIntermBlock *getAsBlock() { return nullptr; } virtual TIntermBlock *getAsBlock() { return nullptr; }
virtual TIntermFunctionPrototype *getAsFunctionPrototypeNode() { return nullptr; }
virtual TIntermDeclaration *getAsDeclarationNode() { return nullptr; } virtual TIntermDeclaration *getAsDeclarationNode() { return nullptr; }
virtual TIntermSwizzle *getAsSwizzleNode() { return nullptr; } virtual TIntermSwizzle *getAsSwizzleNode() { return nullptr; }
virtual TIntermBinary *getAsBinaryNode() { return 0; } virtual TIntermBinary *getAsBinaryNode() { return 0; }
...@@ -704,6 +706,43 @@ class TIntermBlock : public TIntermNode, public TIntermAggregateBase ...@@ -704,6 +706,43 @@ class TIntermBlock : public TIntermNode, public TIntermAggregateBase
TIntermSequence mStatements; TIntermSequence mStatements;
}; };
// Function prototype declaration. The type of the node is the function return type.
class TIntermFunctionPrototype : public TIntermTyped, public TIntermAggregateBase
{
public:
TIntermFunctionPrototype(const TType &type) : TIntermTyped(type) {}
~TIntermFunctionPrototype() {}
TIntermFunctionPrototype *getAsFunctionPrototypeNode() override { return this; }
void traverse(TIntermTraverser *it) override;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
TIntermTyped *deepCopy() const override
{
UNREACHABLE();
return nullptr;
}
bool hasSideEffects() const override
{
UNREACHABLE();
return true;
}
// Only intended for initially building the declaration.
void appendParameter(TIntermSymbol *parameter);
TIntermSequence *getSequence() override { return &mParameters; }
const TIntermSequence *getSequence() const override { return &mParameters; }
TFunctionSymbolInfo *getFunctionSymbolInfo() { return &mFunctionInfo; }
const TFunctionSymbolInfo *getFunctionSymbolInfo() const { return &mFunctionInfo; }
protected:
TIntermSequence mParameters;
TFunctionSymbolInfo mFunctionInfo;
};
// Struct, interface block or variable declaration. Can contain multiple variable declarators. // Struct, interface block or variable declaration. Can contain multiple variable declarators.
class TIntermDeclaration : public TIntermNode, public TIntermAggregateBase class TIntermDeclaration : public TIntermNode, public TIntermAggregateBase
{ {
...@@ -877,6 +916,10 @@ class TIntermTraverser : angle::NonCopyable ...@@ -877,6 +916,10 @@ class TIntermTraverser : angle::NonCopyable
virtual bool visitIfElse(Visit visit, TIntermIfElse *node) { return true; } virtual bool visitIfElse(Visit visit, TIntermIfElse *node) { return true; }
virtual bool visitSwitch(Visit visit, TIntermSwitch *node) { return true; } virtual bool visitSwitch(Visit visit, TIntermSwitch *node) { return true; }
virtual bool visitCase(Visit visit, TIntermCase *node) { return true; } virtual bool visitCase(Visit visit, TIntermCase *node) { return true; }
virtual bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node)
{
return true;
}
virtual bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) virtual bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{ {
return true; return true;
...@@ -904,6 +947,7 @@ class TIntermTraverser : angle::NonCopyable ...@@ -904,6 +947,7 @@ class TIntermTraverser : angle::NonCopyable
virtual void traverseIfElse(TIntermIfElse *node); virtual void traverseIfElse(TIntermIfElse *node);
virtual void traverseSwitch(TIntermSwitch *node); virtual void traverseSwitch(TIntermSwitch *node);
virtual void traverseCase(TIntermCase *node); virtual void traverseCase(TIntermCase *node);
virtual void traverseFunctionPrototype(TIntermFunctionPrototype *node);
virtual void traverseFunctionDefinition(TIntermFunctionDefinition *node); virtual void traverseFunctionDefinition(TIntermFunctionDefinition *node);
virtual void traverseAggregate(TIntermAggregate *node); virtual void traverseAggregate(TIntermAggregate *node);
virtual void traverseBlock(TIntermBlock *node); virtual void traverseBlock(TIntermBlock *node);
...@@ -1123,6 +1167,7 @@ class TLValueTrackingTraverser : public TIntermTraverser ...@@ -1123,6 +1167,7 @@ class TLValueTrackingTraverser : public TIntermTraverser
void traverseBinary(TIntermBinary *node) final; void traverseBinary(TIntermBinary *node) final;
void traverseUnary(TIntermUnary *node) final; void traverseUnary(TIntermUnary *node) final;
void traverseFunctionPrototype(TIntermFunctionPrototype *node) final;
void traverseFunctionDefinition(TIntermFunctionDefinition *node) final; void traverseFunctionDefinition(TIntermFunctionDefinition *node) final;
void traverseAggregate(TIntermAggregate *node) final; void traverseAggregate(TIntermAggregate *node) final;
......
...@@ -81,6 +81,11 @@ void TIntermDeclaration::traverse(TIntermTraverser *it) ...@@ -81,6 +81,11 @@ void TIntermDeclaration::traverse(TIntermTraverser *it)
it->traverseDeclaration(this); it->traverseDeclaration(this);
} }
void TIntermFunctionPrototype::traverse(TIntermTraverser *it)
{
it->traverseFunctionPrototype(this);
}
void TIntermAggregate::traverse(TIntermTraverser *it) void TIntermAggregate::traverse(TIntermTraverser *it)
{ {
it->traverseAggregate(this); it->traverseAggregate(this);
...@@ -559,6 +564,36 @@ void TIntermTraverser::traverseDeclaration(TIntermDeclaration *node) ...@@ -559,6 +564,36 @@ void TIntermTraverser::traverseDeclaration(TIntermDeclaration *node)
visitDeclaration(PostVisit, node); visitDeclaration(PostVisit, node);
} }
void TIntermTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node)
{
bool visit = true;
TIntermSequence *sequence = node->getSequence();
if (preVisit)
visit = visitFunctionPrototype(PreVisit, node);
if (visit)
{
incrementDepth(node);
for (auto *child : *sequence)
{
child->traverse(this);
if (visit && inVisit)
{
if (child != sequence->back())
visit = visitFunctionPrototype(InVisit, node);
}
}
decrementDepth();
}
if (visit && postVisit)
visitFunctionPrototype(PostVisit, node);
}
// Traverse an aggregate node. Same comments in binary node apply here. // Traverse an aggregate node. Same comments in binary node apply here.
void TIntermTraverser::traverseAggregate(TIntermAggregate *node) void TIntermTraverser::traverseAggregate(TIntermAggregate *node)
{ {
...@@ -600,15 +635,19 @@ void TLValueTrackingTraverser::traverseFunctionDefinition(TIntermFunctionDefinit ...@@ -600,15 +635,19 @@ void TLValueTrackingTraverser::traverseFunctionDefinition(TIntermFunctionDefinit
TIntermTraverser::traverseFunctionDefinition(node); TIntermTraverser::traverseFunctionDefinition(node);
} }
void TLValueTrackingTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node)
{
TIntermSequence *sequence = node->getSequence();
addToFunctionMap(node->getFunctionSymbolInfo()->getNameObj(), sequence);
TIntermTraverser::traverseFunctionPrototype(node);
}
void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node) void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
{ {
bool visit = true; bool visit = true;
TIntermSequence *sequence = node->getSequence(); TIntermSequence *sequence = node->getSequence();
if (node->getOp() == EOpPrototype)
{
addToFunctionMap(node->getFunctionSymbolInfo()->getNameObj(), sequence);
}
if (preVisit) if (preVisit)
visit = visitAggregate(PreVisit, node); visit = visitAggregate(PreVisit, node);
......
...@@ -10,7 +10,7 @@ const char *GetOperatorString(TOperator op) ...@@ -10,7 +10,7 @@ const char *GetOperatorString(TOperator op)
{ {
switch (op) switch (op)
{ {
// Note: ops from EOpNull to EOpPrototype can't be handled here. // Note: ops from EOpNull to EOpParameters can't be handled here.
case EOpNegative: case EOpNegative:
return "-"; return "-";
......
...@@ -16,8 +16,6 @@ enum TOperator ...@@ -16,8 +16,6 @@ enum TOperator
EOpFunctionCall, EOpFunctionCall,
EOpParameters, // an aggregate listing the parameters to a function EOpParameters, // an aggregate listing the parameters to a function
EOpPrototype,
// //
// Unary operators // Unary operators
// //
......
...@@ -871,30 +871,31 @@ bool TOutputGLSLBase::visitInvariantDeclaration(Visit visit, TIntermInvariantDec ...@@ -871,30 +871,31 @@ bool TOutputGLSLBase::visitInvariantDeclaration(Visit visit, TIntermInvariantDec
return false; return false;
} }
bool TOutputGLSLBase::visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node)
{
TInfoSinkBase &out = objSink();
ASSERT(visit == PreVisit);
const TType &type = node->getType();
writeVariableType(type);
if (type.isArray())
out << arrayBrackets(type);
out << " " << hashFunctionNameIfNeeded(node->getFunctionSymbolInfo()->getNameObj());
out << "(";
writeFunctionParameters(*(node->getSequence()));
out << ")";
return false;
}
bool TOutputGLSLBase::visitAggregate(Visit visit, TIntermAggregate *node) bool TOutputGLSLBase::visitAggregate(Visit visit, TIntermAggregate *node)
{ {
bool visitChildren = true; bool visitChildren = true;
TInfoSinkBase &out = objSink(); TInfoSinkBase &out = objSink();
switch (node->getOp()) switch (node->getOp())
{ {
case EOpPrototype:
// Function declaration.
ASSERT(visit == PreVisit);
{
const TType &type = node->getType();
writeVariableType(type);
if (type.isArray())
out << arrayBrackets(type);
}
out << " " << hashFunctionNameIfNeeded(node->getFunctionSymbolInfo()->getNameObj());
out << "(";
writeFunctionParameters(*(node->getSequence()));
out << ")";
visitChildren = false;
break;
case EOpFunctionCall: case EOpFunctionCall:
// Function call. // Function call.
if (visit == PreVisit) if (visit == PreVisit)
......
...@@ -57,6 +57,7 @@ class TOutputGLSLBase : public TIntermTraverser ...@@ -57,6 +57,7 @@ class TOutputGLSLBase : public TIntermTraverser
bool visitIfElse(Visit visit, TIntermIfElse *node) override; bool visitIfElse(Visit visit, TIntermIfElse *node) override;
bool visitSwitch(Visit visit, TIntermSwitch *node) override; bool visitSwitch(Visit visit, TIntermSwitch *node) override;
bool visitCase(Visit visit, TIntermCase *node) override; bool visitCase(Visit visit, TIntermCase *node) override;
bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override;
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override; bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *node) override; bool visitAggregate(Visit visit, TIntermAggregate *node) override;
bool visitBlock(Visit visit, TIntermBlock *node) override; bool visitBlock(Visit visit, TIntermBlock *node) override;
......
...@@ -1601,60 +1601,57 @@ bool OutputHLSL::visitInvariantDeclaration(Visit visit, TIntermInvariantDeclarat ...@@ -1601,60 +1601,57 @@ bool OutputHLSL::visitInvariantDeclaration(Visit visit, TIntermInvariantDeclarat
return false; return false;
} }
bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node) bool OutputHLSL::visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node)
{ {
TInfoSinkBase &out = getInfoSink(); TInfoSinkBase &out = getInfoSink();
switch (node->getOp()) ASSERT(visit == PreVisit);
size_t index = mCallDag.findIndex(node->getFunctionSymbolInfo());
// Skip the prototype if it is not implemented (and thus not used)
if (index == CallDAG::InvalidIndex)
{ {
case EOpPrototype: return false;
if (visit == PreVisit) }
{
size_t index = mCallDag.findIndex(node->getFunctionSymbolInfo());
// Skip the prototype if it is not implemented (and thus not used)
if (index == CallDAG::InvalidIndex)
{
return false;
}
TIntermSequence *arguments = node->getSequence(); TIntermSequence *arguments = node->getSequence();
TString name = TString name = DecorateFunctionIfNeeded(node->getFunctionSymbolInfo()->getNameObj());
DecorateFunctionIfNeeded(node->getFunctionSymbolInfo()->getNameObj()); out << TypeString(node->getType()) << " " << name << DisambiguateFunctionName(arguments)
out << TypeString(node->getType()) << " " << name << (mOutputLod0Function ? "Lod0(" : "(");
<< DisambiguateFunctionName(arguments) << (mOutputLod0Function ? "Lod0(" : "(");
for (unsigned int i = 0; i < arguments->size(); i++) for (unsigned int i = 0; i < arguments->size(); i++)
{ {
TIntermSymbol *symbol = (*arguments)[i]->getAsSymbolNode(); TIntermSymbol *symbol = (*arguments)[i]->getAsSymbolNode();
ASSERT(symbol != nullptr);
if (symbol) out << argumentString(symbol);
{
out << argumentString(symbol);
if (i < arguments->size() - 1) if (i < arguments->size() - 1)
{ {
out << ", "; out << ", ";
} }
} }
else
UNREACHABLE();
}
out << ");\n"; out << ");\n";
// Also prototype the Lod0 variant if needed // Also prototype the Lod0 variant if needed
bool needsLod0 = mASTMetadataList[index].mNeedsLod0; bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER) if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
{ {
mOutputLod0Function = true; mOutputLod0Function = true;
node->traverse(this); node->traverse(this);
mOutputLod0Function = false; mOutputLod0Function = false;
} }
return false; return false;
} }
break;
bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
{
TInfoSinkBase &out = getInfoSink();
switch (node->getOp())
{
case EOpFunctionCall: case EOpFunctionCall:
{ {
TIntermSequence *arguments = node->getSequence(); TIntermSequence *arguments = node->getSequence();
......
...@@ -76,6 +76,7 @@ class OutputHLSL : public TIntermTraverser ...@@ -76,6 +76,7 @@ class OutputHLSL : public TIntermTraverser
bool visitIfElse(Visit visit, TIntermIfElse *); bool visitIfElse(Visit visit, TIntermIfElse *);
bool visitSwitch(Visit visit, TIntermSwitch *); bool visitSwitch(Visit visit, TIntermSwitch *);
bool visitCase(Visit visit, TIntermCase *); bool visitCase(Visit visit, TIntermCase *);
bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override;
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override; bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *); bool visitAggregate(Visit visit, TIntermAggregate *);
bool visitBlock(Visit visit, TIntermBlock *node); bool visitBlock(Visit visit, TIntermBlock *node);
......
...@@ -2395,8 +2395,9 @@ void TParseContext::parseGlobalLayoutQualifier(const TTypeQualifierBuilder &type ...@@ -2395,8 +2395,9 @@ void TParseContext::parseGlobalLayoutQualifier(const TTypeQualifierBuilder &type
} }
} }
TIntermAggregate *TParseContext::addFunctionPrototypeDeclaration(const TFunction &parsedFunction, TIntermFunctionPrototype *TParseContext::addFunctionPrototypeDeclaration(
const TSourceLoc &location) const TFunction &parsedFunction,
const TSourceLoc &location)
{ {
// Note: function found from the symbol table could be the same as parsedFunction if this is the // Note: function found from the symbol table could be the same as parsedFunction if this is the
// first declaration. Either way the instance in the symbol table is used to track whether the // first declaration. Either way the instance in the symbol table is used to track whether the
...@@ -2411,11 +2412,11 @@ TIntermAggregate *TParseContext::addFunctionPrototypeDeclaration(const TFunction ...@@ -2411,11 +2412,11 @@ TIntermAggregate *TParseContext::addFunctionPrototypeDeclaration(const TFunction
} }
function->setHasPrototypeDeclaration(); function->setHasPrototypeDeclaration();
TIntermAggregate *prototype = new TIntermAggregate; TIntermFunctionPrototype *prototype = new TIntermFunctionPrototype(function->getReturnType());
// TODO(oetuaho@nvidia.com): Instead of converting the function information here, the node could // TODO(oetuaho@nvidia.com): Instead of converting the function information here, the node could
// point to the data that already exists in the symbol table. // point to the data that already exists in the symbol table.
prototype->setType(function->getReturnType());
prototype->getFunctionSymbolInfo()->setFromFunction(*function); prototype->getFunctionSymbolInfo()->setFromFunction(*function);
prototype->setLine(location);
for (size_t i = 0; i < function->getParamCount(); i++) for (size_t i = 0; i < function->getParamCount(); i++)
{ {
...@@ -2426,17 +2427,15 @@ TIntermAggregate *TParseContext::addFunctionPrototypeDeclaration(const TFunction ...@@ -2426,17 +2427,15 @@ TIntermAggregate *TParseContext::addFunctionPrototypeDeclaration(const TFunction
TIntermSymbol *paramSymbol = intermediate.addSymbol( TIntermSymbol *paramSymbol = intermediate.addSymbol(
variable.getUniqueId(), variable.getName(), variable.getType(), location); variable.getUniqueId(), variable.getName(), variable.getType(), location);
prototype = intermediate.growAggregate(prototype, paramSymbol, location); prototype->appendParameter(paramSymbol);
} }
else else
{ {
TIntermSymbol *paramSymbol = intermediate.addSymbol(0, "", *param.type, location); TIntermSymbol *paramSymbol = intermediate.addSymbol(0, "", *param.type, location);
prototype = intermediate.growAggregate(prototype, paramSymbol, location); prototype->appendParameter(paramSymbol);
} }
} }
prototype->setOp(EOpPrototype);
symbolTable.pop(); symbolTable.pop();
if (!symbolTable.atGlobalLevel()) if (!symbolTable.atGlobalLevel())
......
...@@ -240,8 +240,8 @@ class TParseContext : angle::NonCopyable ...@@ -240,8 +240,8 @@ class TParseContext : angle::NonCopyable
TIntermDeclaration *declarationOut); TIntermDeclaration *declarationOut);
void parseGlobalLayoutQualifier(const TTypeQualifierBuilder &typeQualifierBuilder); void parseGlobalLayoutQualifier(const TTypeQualifierBuilder &typeQualifierBuilder);
TIntermAggregate *addFunctionPrototypeDeclaration(const TFunction &parsedFunction, TIntermFunctionPrototype *addFunctionPrototypeDeclaration(const TFunction &parsedFunction,
const TSourceLoc &location); const TSourceLoc &location);
TIntermFunctionDefinition *addFunctionDefinition(const TFunction &function, TIntermFunctionDefinition *addFunctionDefinition(const TFunction &function,
TIntermAggregate *functionParameters, TIntermAggregate *functionParameters,
TIntermBlock *functionBody, TIntermBlock *functionBody,
......
...@@ -49,6 +49,7 @@ class TOutputTraverser : public TIntermTraverser ...@@ -49,6 +49,7 @@ class TOutputTraverser : public TIntermTraverser
bool visitIfElse(Visit visit, TIntermIfElse *node) override; bool visitIfElse(Visit visit, TIntermIfElse *node) override;
bool visitSwitch(Visit visit, TIntermSwitch *node) override; bool visitSwitch(Visit visit, TIntermSwitch *node) override;
bool visitCase(Visit visit, TIntermCase *node) override; bool visitCase(Visit visit, TIntermCase *node) override;
bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override;
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override; bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *) override; bool visitAggregate(Visit visit, TIntermAggregate *) override;
bool visitBlock(Visit visit, TIntermBlock *) override; bool visitBlock(Visit visit, TIntermBlock *) override;
...@@ -506,6 +507,18 @@ bool TOutputTraverser::visitInvariantDeclaration(Visit visit, TIntermInvariantDe ...@@ -506,6 +507,18 @@ bool TOutputTraverser::visitInvariantDeclaration(Visit visit, TIntermInvariantDe
return true; return true;
} }
bool TOutputTraverser::visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node)
{
TInfoSinkBase &out = sink;
OutputTreeText(out, node, mDepth);
OutputFunction(out, "Function Prototype", node->getFunctionSymbolInfo());
out << " (" << node->getCompleteString() << ")";
out << "\n";
return true;
}
bool TOutputTraverser::visitAggregate(Visit visit, TIntermAggregate *node) bool TOutputTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
{ {
TInfoSinkBase &out = sink; TInfoSinkBase &out = sink;
...@@ -527,9 +540,6 @@ bool TOutputTraverser::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -527,9 +540,6 @@ bool TOutputTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
case EOpParameters: case EOpParameters:
out << "Function Parameters: "; out << "Function Parameters: ";
break; break;
case EOpPrototype:
OutputFunction(out, "Function Prototype", node->getFunctionSymbolInfo());
break;
case EOpConstructFloat: case EOpConstructFloat:
out << "Construct float"; out << "Construct float";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment