Commit beb6dc74 by Olli Etuaho Committed by Commit Bot

Always use TFunction instead of TFunctionSymbolInfo

This reduces unnecessary memory allocations and conversions between different objects containing the same data. BUG=angleproject:2267 TEST=angle_unittests Change-Id: I87316509ab1cd6d36756ff6af7fa2b5c5a76a8ea Reviewed-on: https://chromium-review.googlesource.com/827134Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org> Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
parent 1bb8528c
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "compiler/translator/IntermNode_util.h" #include "compiler/translator/IntermNode_util.h"
#include "compiler/translator/IntermTraverse.h" #include "compiler/translator/IntermTraverse.h"
#include "compiler/translator/StaticType.h"
#include "compiler/translator/SymbolTable.h" #include "compiler/translator/SymbolTable.h"
namespace sh namespace sh
...@@ -29,24 +30,6 @@ void CopyAggregateChildren(TIntermAggregateBase *from, TIntermAggregateBase *to) ...@@ -29,24 +30,6 @@ void CopyAggregateChildren(TIntermAggregateBase *from, TIntermAggregateBase *to)
} }
} }
TIntermAggregate *CreateReplacementCall(TIntermAggregate *originalCall,
TIntermTyped *returnValueTarget)
{
TIntermSequence *replacementArguments = new TIntermSequence();
TIntermSequence *originalArguments = originalCall->getSequence();
for (auto &arg : *originalArguments)
{
replacementArguments->push_back(arg);
}
replacementArguments->push_back(returnValueTarget);
ASSERT(originalCall->getFunction());
TIntermAggregate *replacementCall =
TIntermAggregate::CreateFunctionCall(*originalCall->getFunction(), replacementArguments);
replacementCall->setType(TType(EbtVoid));
replacementCall->setLine(originalCall->getLine());
return replacementCall;
}
class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser
{ {
public: public:
...@@ -61,15 +44,43 @@ class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser ...@@ -61,15 +44,43 @@ class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser
bool visitBranch(Visit visit, TIntermBranch *node) override; bool visitBranch(Visit visit, TIntermBranch *node) override;
bool visitBinary(Visit visit, TIntermBinary *node) override; bool visitBinary(Visit visit, TIntermBinary *node) override;
TIntermAggregate *createReplacementCall(TIntermAggregate *originalCall,
TIntermTyped *returnValueTarget);
// Set when traversal is inside a function with array return value. // Set when traversal is inside a function with array return value.
TIntermFunctionDefinition *mFunctionWithArrayReturnValue; TIntermFunctionDefinition *mFunctionWithArrayReturnValue;
// Map from function symbol ids to array return value variables. struct ChangedFunction
std::map<int, TVariable *> mReturnValueVariables; {
const TVariable *returnValueVariable;
const TFunction *func;
};
// Map from function symbol ids to the changed function.
std::map<int, ChangedFunction> mChangedFunctions;
const TString *const mReturnValueVariableName; const TString *const mReturnValueVariableName;
}; };
TIntermAggregate *ArrayReturnValueToOutParameterTraverser::createReplacementCall(
TIntermAggregate *originalCall,
TIntermTyped *returnValueTarget)
{
TIntermSequence *replacementArguments = new TIntermSequence();
TIntermSequence *originalArguments = originalCall->getSequence();
for (auto &arg : *originalArguments)
{
replacementArguments->push_back(arg);
}
replacementArguments->push_back(returnValueTarget);
ASSERT(originalCall->getFunction());
const TSymbolUniqueId &originalId = originalCall->getFunction()->uniqueId();
TIntermAggregate *replacementCall = TIntermAggregate::CreateFunctionCall(
*mChangedFunctions[originalId.get()].func, replacementArguments);
replacementCall->setLine(originalCall->getLine());
return replacementCall;
}
void ArrayReturnValueToOutParameterTraverser::apply(TIntermNode *root, TSymbolTable *symbolTable) void ArrayReturnValueToOutParameterTraverser::apply(TIntermNode *root, TSymbolTable *symbolTable)
{ {
ArrayReturnValueToOutParameterTraverser arrayReturnValueToOutParam(symbolTable); ArrayReturnValueToOutParameterTraverser arrayReturnValueToOutParam(symbolTable);
...@@ -108,21 +119,25 @@ bool ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(Visit visit ...@@ -108,21 +119,25 @@ bool ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(Visit visit
{ {
// Replace the whole prototype node with another node that has the out parameter // Replace the whole prototype node with another node that has the out parameter
// added. Also set the function to return void. // added. Also set the function to return void.
TIntermFunctionPrototype *replacement = const TSymbolUniqueId &functionId = node->getFunction()->uniqueId();
new TIntermFunctionPrototype(TType(EbtVoid), node->getFunctionSymbolInfo()->getId()); if (mChangedFunctions.find(functionId.get()) == mChangedFunctions.end())
CopyAggregateChildren(node, replacement);
const TSymbolUniqueId &functionId = node->getFunctionSymbolInfo()->getId();
if (mReturnValueVariables.find(functionId.get()) == mReturnValueVariables.end())
{ {
TType returnValueVariableType(node->getType()); TType returnValueVariableType(node->getType());
returnValueVariableType.setQualifier(EvqOut); returnValueVariableType.setQualifier(EvqOut);
mReturnValueVariables[functionId.get()] = ChangedFunction changedFunction;
changedFunction.returnValueVariable =
new TVariable(mSymbolTable, mReturnValueVariableName, returnValueVariableType, new TVariable(mSymbolTable, mReturnValueVariableName, returnValueVariableType,
SymbolType::AngleInternal); SymbolType::AngleInternal);
changedFunction.func = new TFunction(mSymbolTable, node->getFunction()->name(),
StaticType::GetBasic<EbtVoid>(),
node->getFunction()->symbolType(), false);
mChangedFunctions[functionId.get()] = changedFunction;
} }
TIntermFunctionPrototype *replacement =
new TIntermFunctionPrototype(mChangedFunctions[functionId.get()].func);
CopyAggregateChildren(node, replacement);
replacement->getSequence()->push_back( replacement->getSequence()->push_back(
new TIntermSymbol(mReturnValueVariables[functionId.get()])); new TIntermSymbol(mChangedFunctions[functionId.get()].returnValueVariable));
*replacement->getFunctionSymbolInfo() = *node->getFunctionSymbolInfo();
replacement->setLine(node->getLine()); replacement->setLine(node->getLine());
queueReplacement(replacement, OriginalNode::IS_DROPPED); queueReplacement(replacement, OriginalNode::IS_DROPPED);
...@@ -161,7 +176,7 @@ bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TInter ...@@ -161,7 +176,7 @@ bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TInter
// f(s0); // f(s0);
TIntermSymbol *returnValueSymbol = CreateTempSymbolNode(returnValue); TIntermSymbol *returnValueSymbol = CreateTempSymbolNode(returnValue);
replacements.push_back(CreateReplacementCall(node, returnValueSymbol)); replacements.push_back(createReplacementCall(node, returnValueSymbol));
mMultiReplacements.push_back( mMultiReplacements.push_back(
NodeReplaceWithMultipleEntry(parentBlock, node, replacements)); NodeReplaceWithMultipleEntry(parentBlock, node, replacements));
} }
...@@ -180,10 +195,10 @@ bool ArrayReturnValueToOutParameterTraverser::visitBranch(Visit visit, TIntermBr ...@@ -180,10 +195,10 @@ bool ArrayReturnValueToOutParameterTraverser::visitBranch(Visit visit, TIntermBr
TIntermTyped *expression = node->getExpression(); TIntermTyped *expression = node->getExpression();
ASSERT(expression != nullptr); ASSERT(expression != nullptr);
const TSymbolUniqueId &functionId = const TSymbolUniqueId &functionId =
mFunctionWithArrayReturnValue->getFunctionSymbolInfo()->getId(); mFunctionWithArrayReturnValue->getFunction()->uniqueId();
ASSERT(mReturnValueVariables.find(functionId.get()) != mReturnValueVariables.end()); ASSERT(mChangedFunctions.find(functionId.get()) != mChangedFunctions.end());
TIntermSymbol *returnValueSymbol = TIntermSymbol *returnValueSymbol =
new TIntermSymbol(mReturnValueVariables[functionId.get()]); new TIntermSymbol(mChangedFunctions[functionId.get()].returnValueVariable);
TIntermBinary *replacementAssignment = TIntermBinary *replacementAssignment =
new TIntermBinary(EOpAssign, returnValueSymbol, expression); new TIntermBinary(EOpAssign, returnValueSymbol, expression);
replacementAssignment->setLine(expression->getLine()); replacementAssignment->setLine(expression->getLine());
...@@ -207,7 +222,7 @@ bool ArrayReturnValueToOutParameterTraverser::visitBinary(Visit visit, TIntermBi ...@@ -207,7 +222,7 @@ bool ArrayReturnValueToOutParameterTraverser::visitBinary(Visit visit, TIntermBi
ASSERT(rightAgg == nullptr || rightAgg->getOp() != EOpCallInternalRawFunction); ASSERT(rightAgg == nullptr || rightAgg->getOp() != EOpCallInternalRawFunction);
if (rightAgg != nullptr && rightAgg->getOp() == EOpCallFunctionInAST) if (rightAgg != nullptr && rightAgg->getOp() == EOpCallFunctionInAST)
{ {
TIntermAggregate *replacementCall = CreateReplacementCall(rightAgg, node->getLeft()); TIntermAggregate *replacementCall = createReplacementCall(rightAgg, node->getLeft());
queueReplacement(replacementCall, OriginalNode::IS_DROPPED); queueReplacement(replacementCall, OriginalNode::IS_DROPPED);
} }
} }
......
...@@ -81,8 +81,7 @@ class CallDAG::CallDAGCreator : public TIntermTraverser ...@@ -81,8 +81,7 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
record.callees.push_back(static_cast<int>(callee->index)); record.callees.push_back(static_cast<int>(callee->index));
} }
(*idToIndex)[data.node->getFunctionSymbolInfo()->getId().get()] = (*idToIndex)[data.node->getFunction()->uniqueId().get()] = static_cast<int>(data.index);
static_cast<int>(data.index);
} }
} }
...@@ -104,17 +103,17 @@ class CallDAG::CallDAGCreator : public TIntermTraverser ...@@ -104,17 +103,17 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
// Create the record if need be and remember the node. // Create the record if need be and remember the node.
if (visit == PreVisit) if (visit == PreVisit)
{ {
auto it = mFunctions.find(node->getFunctionSymbolInfo()->getId().get()); auto it = mFunctions.find(node->getFunction()->uniqueId().get());
if (it == mFunctions.end()) if (it == mFunctions.end())
{ {
mCurrentFunction = &mFunctions[node->getFunctionSymbolInfo()->getId().get()]; mCurrentFunction = &mFunctions[node->getFunction()->uniqueId().get()];
mCurrentFunction->name = node->getFunctionSymbolInfo()->getName(); mCurrentFunction->name = *node->getFunction()->name();
} }
else else
{ {
mCurrentFunction = &it->second; mCurrentFunction = &it->second;
ASSERT(mCurrentFunction->name == node->getFunctionSymbolInfo()->getName()); ASSERT(mCurrentFunction->name == *node->getFunction()->name());
} }
mCurrentFunction->node = node; mCurrentFunction->node = node;
...@@ -135,8 +134,8 @@ class CallDAG::CallDAGCreator : public TIntermTraverser ...@@ -135,8 +134,8 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
} }
// Function declaration, create an empty record. // Function declaration, create an empty record.
auto &record = mFunctions[node->getFunctionSymbolInfo()->getId().get()]; auto &record = mFunctions[node->getFunction()->uniqueId().get()];
record.name = node->getFunctionSymbolInfo()->getName(); record.name = *node->getFunction()->name();
// No need to traverse the parameters. // No need to traverse the parameters.
return false; return false;
......
...@@ -993,22 +993,22 @@ class TCompiler::UnusedPredicate ...@@ -993,22 +993,22 @@ class TCompiler::UnusedPredicate
const TIntermFunctionPrototype *asFunctionPrototype = node->getAsFunctionPrototypeNode(); const TIntermFunctionPrototype *asFunctionPrototype = node->getAsFunctionPrototypeNode();
const TIntermFunctionDefinition *asFunctionDefinition = node->getAsFunctionDefinition(); const TIntermFunctionDefinition *asFunctionDefinition = node->getAsFunctionDefinition();
const TFunctionSymbolInfo *functionInfo = nullptr; const TFunction *func = nullptr;
if (asFunctionDefinition) if (asFunctionDefinition)
{ {
functionInfo = asFunctionDefinition->getFunctionSymbolInfo(); func = asFunctionDefinition->getFunction();
} }
else if (asFunctionPrototype) else if (asFunctionPrototype)
{ {
functionInfo = asFunctionPrototype->getFunctionSymbolInfo(); func = asFunctionPrototype->getFunction();
} }
if (functionInfo == nullptr) if (func == nullptr)
{ {
return false; return false;
} }
size_t callDagIndex = mCallDag->findIndex(functionInfo->getId()); size_t callDagIndex = mCallDag->findIndex(func->uniqueId());
if (callDagIndex == CallDAG::InvalidIndex) if (callDagIndex == CallDAG::InvalidIndex)
{ {
// This happens only for unimplemented prototypes which are thus unused // This happens only for unimplemented prototypes which are thus unused
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "compiler/translator/FindMain.h" #include "compiler/translator/FindMain.h"
#include "compiler/translator/IntermNode.h" #include "compiler/translator/IntermNode.h"
#include "compiler/translator/Symbol.h"
namespace sh namespace sh
{ {
...@@ -18,7 +19,7 @@ TIntermFunctionDefinition *FindMain(TIntermBlock *root) ...@@ -18,7 +19,7 @@ TIntermFunctionDefinition *FindMain(TIntermBlock *root)
for (TIntermNode *node : *root->getSequence()) for (TIntermNode *node : *root->getSequence())
{ {
TIntermFunctionDefinition *nodeFunction = node->getAsFunctionDefinition(); TIntermFunctionDefinition *nodeFunction = node->getAsFunctionDefinition();
if (nodeFunction != nullptr && nodeFunction->getFunctionSymbolInfo()->isMain()) if (nodeFunction != nullptr && nodeFunction->getFunction()->isMain())
{ {
return nodeFunction; return nodeFunction;
} }
......
...@@ -595,49 +595,10 @@ TIntermConstantUnion::TIntermConstantUnion(const TIntermConstantUnion &node) : T ...@@ -595,49 +595,10 @@ TIntermConstantUnion::TIntermConstantUnion(const TIntermConstantUnion &node) : T
mUnionArrayPointer = node.mUnionArrayPointer; mUnionArrayPointer = node.mUnionArrayPointer;
} }
void TFunctionSymbolInfo::setFromFunction(const TFunction &function) TIntermFunctionPrototype::TIntermFunctionPrototype(const TFunction *function)
: TIntermTyped(function->getReturnType()), mFunction(function)
{ {
mName.setString(*function.name()); ASSERT(mFunction->symbolType() != SymbolType::Empty);
mName.setInternal(function.symbolType() == SymbolType::AngleInternal);
setId(TSymbolUniqueId(function));
}
TFunctionSymbolInfo::TFunctionSymbolInfo(const TSymbolUniqueId &id) : mId(new TSymbolUniqueId(id))
{
}
TFunctionSymbolInfo::TFunctionSymbolInfo(const TFunctionSymbolInfo &info)
: mName(info.mName), mId(nullptr)
{
if (info.mId)
{
mId = new TSymbolUniqueId(*info.mId);
}
}
TFunctionSymbolInfo &TFunctionSymbolInfo::operator=(const TFunctionSymbolInfo &info)
{
mName = info.mName;
if (info.mId)
{
mId = new TSymbolUniqueId(*info.mId);
}
else
{
mId = nullptr;
}
return *this;
}
void TFunctionSymbolInfo::setId(const TSymbolUniqueId &id)
{
mId = new TSymbolUniqueId(id);
}
const TSymbolUniqueId &TFunctionSymbolInfo::getId() const
{
ASSERT(mId);
return *mId;
} }
TIntermAggregate::TIntermAggregate(const TIntermAggregate &node) TIntermAggregate::TIntermAggregate(const TIntermAggregate &node)
......
...@@ -522,32 +522,6 @@ class TIntermUnary : public TIntermOperator ...@@ -522,32 +522,6 @@ class TIntermUnary : public TIntermOperator
TIntermUnary(const TIntermUnary &node); // note: not deleted, just private! TIntermUnary(const TIntermUnary &node); // note: not deleted, just private!
}; };
class TFunctionSymbolInfo
{
public:
POOL_ALLOCATOR_NEW_DELETE();
TFunctionSymbolInfo(const TSymbolUniqueId &id);
TFunctionSymbolInfo() : mId(nullptr) {}
TFunctionSymbolInfo(const TFunctionSymbolInfo &info);
TFunctionSymbolInfo &operator=(const TFunctionSymbolInfo &info);
void setFromFunction(const TFunction &function);
void setNameObj(const TName &name) { mName = name; }
const TName &getNameObj() const { return mName; }
const TString &getName() const { return mName.getString(); }
bool isMain() const { return mName.getString() == "main"; }
void setId(const TSymbolUniqueId &functionId);
const TSymbolUniqueId &getId() const;
private:
TName mName;
TSymbolUniqueId *mId;
};
typedef TVector<TIntermNode *> TIntermSequence; typedef TVector<TIntermNode *> TIntermSequence;
typedef TVector<int> TQualifierList; typedef TVector<int> TQualifierList;
...@@ -688,12 +662,7 @@ class TIntermBlock : public TIntermNode, public TIntermAggregateBase ...@@ -688,12 +662,7 @@ class TIntermBlock : public TIntermNode, public TIntermAggregateBase
class TIntermFunctionPrototype : public TIntermTyped, public TIntermAggregateBase class TIntermFunctionPrototype : public TIntermTyped, public TIntermAggregateBase
{ {
public: public:
// TODO(oetuaho@nvidia.com): See if TFunctionSymbolInfo could be added to constructor TIntermFunctionPrototype(const TFunction *function);
// parameters.
TIntermFunctionPrototype(const TType &type, const TSymbolUniqueId &id)
: TIntermTyped(type), mFunctionInfo(id)
{
}
~TIntermFunctionPrototype() {} ~TIntermFunctionPrototype() {}
TIntermFunctionPrototype *getAsFunctionPrototypeNode() override { return this; } TIntermFunctionPrototype *getAsFunctionPrototypeNode() override { return this; }
...@@ -717,13 +686,12 @@ class TIntermFunctionPrototype : public TIntermTyped, public TIntermAggregateBas ...@@ -717,13 +686,12 @@ class TIntermFunctionPrototype : public TIntermTyped, public TIntermAggregateBas
TIntermSequence *getSequence() override { return &mParameters; } TIntermSequence *getSequence() override { return &mParameters; }
const TIntermSequence *getSequence() const override { return &mParameters; } const TIntermSequence *getSequence() const override { return &mParameters; }
TFunctionSymbolInfo *getFunctionSymbolInfo() { return &mFunctionInfo; } const TFunction *getFunction() const { return mFunction; }
const TFunctionSymbolInfo *getFunctionSymbolInfo() const { return &mFunctionInfo; }
protected: protected:
TIntermSequence mParameters; TIntermSequence mParameters;
TFunctionSymbolInfo mFunctionInfo; const TFunction *const mFunction;
}; };
// Node for function definitions. The prototype child node stores the function header including // Node for function definitions. The prototype child node stores the function header including
...@@ -745,10 +713,7 @@ class TIntermFunctionDefinition : public TIntermNode ...@@ -745,10 +713,7 @@ class TIntermFunctionDefinition : public TIntermNode
TIntermFunctionPrototype *getFunctionPrototype() const { return mPrototype; } TIntermFunctionPrototype *getFunctionPrototype() const { return mPrototype; }
TIntermBlock *getBody() const { return mBody; } TIntermBlock *getBody() const { return mBody; }
const TFunctionSymbolInfo *getFunctionSymbolInfo() const const TFunction *getFunction() const { return mPrototype->getFunction(); }
{
return mPrototype->getFunctionSymbolInfo();
}
private: private:
TIntermFunctionPrototype *mPrototype; TIntermFunctionPrototype *mPrototype;
......
...@@ -35,16 +35,13 @@ const TFunction *LookUpBuiltInFunction(const TString &name, ...@@ -35,16 +35,13 @@ const TFunction *LookUpBuiltInFunction(const TString &name,
TIntermFunctionPrototype *CreateInternalFunctionPrototypeNode(const TFunction &func) TIntermFunctionPrototype *CreateInternalFunctionPrototypeNode(const TFunction &func)
{ {
TIntermFunctionPrototype *functionNode = return new TIntermFunctionPrototype(&func);
new TIntermFunctionPrototype(func.getReturnType(), func.uniqueId());
functionNode->getFunctionSymbolInfo()->setFromFunction(func);
return functionNode;
} }
TIntermFunctionDefinition *CreateInternalFunctionDefinitionNode(const TFunction &func, TIntermFunctionDefinition *CreateInternalFunctionDefinitionNode(const TFunction &func,
TIntermBlock *functionBody) TIntermBlock *functionBody)
{ {
return new TIntermFunctionDefinition(CreateInternalFunctionPrototypeNode(func), functionBody); return new TIntermFunctionDefinition(new TIntermFunctionPrototype(&func), functionBody);
} }
TIntermTyped *CreateZeroNode(const TType &type) TIntermTyped *CreateZeroNode(const TType &type)
......
...@@ -675,7 +675,7 @@ TLValueTrackingTraverser::TLValueTrackingTraverser(bool preVisit, ...@@ -675,7 +675,7 @@ TLValueTrackingTraverser::TLValueTrackingTraverser(bool preVisit,
void TLValueTrackingTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node) void TLValueTrackingTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node)
{ {
TIntermSequence *sequence = node->getSequence(); TIntermSequence *sequence = node->getSequence();
addToFunctionMap(node->getFunctionSymbolInfo()->getId(), sequence); addToFunctionMap(node->getFunction()->uniqueId(), sequence);
TIntermTraverser::traverseFunctionPrototype(node); TIntermTraverser::traverseFunctionPrototype(node);
} }
......
...@@ -915,7 +915,7 @@ bool TOutputGLSLBase::visitFunctionPrototype(Visit visit, TIntermFunctionPrototy ...@@ -915,7 +915,7 @@ bool TOutputGLSLBase::visitFunctionPrototype(Visit visit, TIntermFunctionPrototy
if (type.isArray()) if (type.isArray())
out << ArrayString(type); out << ArrayString(type);
out << " " << hashFunctionNameIfNeeded(*node->getFunctionSymbolInfo()); out << " " << hashFunctionNameIfNeeded(node->getFunction());
out << "("; out << "(";
writeFunctionParameters(*(node->getSequence())); writeFunctionParameters(*(node->getSequence()));
...@@ -1155,18 +1155,6 @@ TString TOutputGLSLBase::hashFunctionNameIfNeeded(const TFunction *func) ...@@ -1155,18 +1155,6 @@ TString TOutputGLSLBase::hashFunctionNameIfNeeded(const TFunction *func)
} }
} }
TString TOutputGLSLBase::hashFunctionNameIfNeeded(const TFunctionSymbolInfo &info)
{
if (info.isMain())
{
return info.getName();
}
else
{
return hashName(info.getNameObj());
}
}
bool TOutputGLSLBase::structDeclared(const TStructure *structure) const bool TOutputGLSLBase::structDeclared(const TStructure *structure) const
{ {
ASSERT(structure); ASSERT(structure);
......
...@@ -73,7 +73,6 @@ class TOutputGLSLBase : public TIntermTraverser ...@@ -73,7 +73,6 @@ class TOutputGLSLBase : public TIntermTraverser
TString hashVariableName(const TName &name); TString hashVariableName(const TName &name);
// Same as hashName(), but without hashing internal functions or "main". // Same as hashName(), but without hashing internal functions or "main".
TString hashFunctionNameIfNeeded(const TFunction *func); TString hashFunctionNameIfNeeded(const TFunction *func);
TString hashFunctionNameIfNeeded(const TFunctionSymbolInfo &info);
// Used to translate function names for differences between ESSL and GLSL // Used to translate function names for differences between ESSL and GLSL
virtual TString translateTextureFunction(const TString &name) { return name; } virtual TString translateTextureFunction(const TString &name) { return name; }
......
...@@ -1723,7 +1723,7 @@ bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition ...@@ -1723,7 +1723,7 @@ bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition
ASSERT(mCurrentFunctionMetadata == nullptr); ASSERT(mCurrentFunctionMetadata == nullptr);
size_t index = mCallDag.findIndex(node->getFunctionSymbolInfo()->getId()); size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
ASSERT(index != CallDAG::InvalidIndex); ASSERT(index != CallDAG::InvalidIndex);
mCurrentFunctionMetadata = &mASTMetadataList[index]; mCurrentFunctionMetadata = &mASTMetadataList[index];
...@@ -1731,14 +1731,14 @@ bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition ...@@ -1731,14 +1731,14 @@ bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition
TIntermSequence *parameters = node->getFunctionPrototype()->getSequence(); TIntermSequence *parameters = node->getFunctionPrototype()->getSequence();
if (node->getFunctionSymbolInfo()->isMain()) if (node->getFunction()->isMain())
{ {
out << "gl_main("; out << "gl_main(";
} }
else else
{ {
out << DecorateFunctionIfNeeded(node->getFunctionSymbolInfo()->getNameObj()) out << DecorateFunctionIfNeeded(node->getFunction()) << DisambiguateFunctionName(parameters)
<< DisambiguateFunctionName(parameters) << (mOutputLod0Function ? "Lod0(" : "("); << (mOutputLod0Function ? "Lod0(" : "(");
} }
for (unsigned int i = 0; i < parameters->size(); i++) for (unsigned int i = 0; i < parameters->size(); i++)
...@@ -1772,7 +1772,7 @@ bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition ...@@ -1772,7 +1772,7 @@ bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition
bool needsLod0 = mASTMetadataList[index].mNeedsLod0; bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER) if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
{ {
ASSERT(!node->getFunctionSymbolInfo()->isMain()); ASSERT(!node->getFunction()->isMain());
mOutputLod0Function = true; mOutputLod0Function = true;
node->traverse(this); node->traverse(this);
mOutputLod0Function = false; mOutputLod0Function = false;
...@@ -1851,7 +1851,7 @@ bool OutputHLSL::visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *n ...@@ -1851,7 +1851,7 @@ bool OutputHLSL::visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *n
TInfoSinkBase &out = getInfoSink(); TInfoSinkBase &out = getInfoSink();
ASSERT(visit == PreVisit); ASSERT(visit == PreVisit);
size_t index = mCallDag.findIndex(node->getFunctionSymbolInfo()->getId()); size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
// Skip the prototype if it is not implemented (and thus not used) // Skip the prototype if it is not implemented (and thus not used)
if (index == CallDAG::InvalidIndex) if (index == CallDAG::InvalidIndex)
{ {
...@@ -1860,7 +1860,7 @@ bool OutputHLSL::visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *n ...@@ -1860,7 +1860,7 @@ bool OutputHLSL::visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *n
TIntermSequence *arguments = node->getSequence(); TIntermSequence *arguments = node->getSequence();
TString name = DecorateFunctionIfNeeded(node->getFunctionSymbolInfo()->getNameObj()); TString name = DecorateFunctionIfNeeded(node->getFunction());
out << TypeString(node->getType()) << " " << name << DisambiguateFunctionName(arguments) out << TypeString(node->getType()) << " " << name << DisambiguateFunctionName(arguments)
<< (mOutputLod0Function ? "Lod0(" : "("); << (mOutputLod0Function ? "Lod0(" : "(");
...@@ -1914,7 +1914,7 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -1914,7 +1914,7 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
ASSERT(index != CallDAG::InvalidIndex); ASSERT(index != CallDAG::InvalidIndex);
lod0 &= mASTMetadataList[index].mNeedsLod0; lod0 &= mASTMetadataList[index].mNeedsLod0;
out << DecorateFunctionIfNeeded(TName(node->getFunction())); out << DecorateFunctionIfNeeded(node->getFunction());
out << DisambiguateFunctionName(node->getSequence()); out << DisambiguateFunctionName(node->getSequence());
out << (lod0 ? "Lod0(" : "("); out << (lod0 ? "Lod0(" : "(");
} }
...@@ -1922,7 +1922,7 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -1922,7 +1922,7 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
{ {
// This path is used for internal functions that don't have their definitions in the // This path is used for internal functions that don't have their definitions in the
// AST, such as precision emulation functions. // AST, such as precision emulation functions.
out << DecorateFunctionIfNeeded(TName(node->getFunction())) << "("; out << DecorateFunctionIfNeeded(node->getFunction()) << "(";
} }
else if (node->getFunction()->isImageFunction()) else if (node->getFunction()->isImageFunction())
{ {
......
...@@ -13,13 +13,6 @@ namespace sh ...@@ -13,13 +13,6 @@ namespace sh
namespace namespace
{ {
void OutputFunction(TInfoSinkBase &out, const char *str, const TFunctionSymbolInfo *info)
{
const char *internal = info->getNameObj().isInternal() ? " (internal function)" : "";
out << str << internal << ": " << info->getNameObj().getString() << " (symbol id "
<< info->getId().get() << ")";
}
void OutputFunction(TInfoSinkBase &out, const char *str, const TFunction *func) void OutputFunction(TInfoSinkBase &out, const char *str, const TFunction *func)
{ {
const char *internal = const char *internal =
...@@ -364,7 +357,7 @@ bool TOutputTraverser::visitInvariantDeclaration(Visit visit, TIntermInvariantDe ...@@ -364,7 +357,7 @@ bool TOutputTraverser::visitInvariantDeclaration(Visit visit, TIntermInvariantDe
bool TOutputTraverser::visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) bool TOutputTraverser::visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, mDepth);
OutputFunction(mOut, "Function Prototype", node->getFunctionSymbolInfo()); OutputFunction(mOut, "Function Prototype", node->getFunction());
mOut << " (" << node->getCompleteString() << ")"; mOut << " (" << node->getCompleteString() << ")";
mOut << "\n"; mOut << "\n";
......
...@@ -3156,11 +3156,7 @@ TIntermFunctionPrototype *TParseContext::createPrototypeNodeFromFunction( ...@@ -3156,11 +3156,7 @@ TIntermFunctionPrototype *TParseContext::createPrototypeNodeFromFunction(
ASSERT(function.name()); ASSERT(function.name());
checkIsNotReserved(location, *function.name()); checkIsNotReserved(location, *function.name());
TIntermFunctionPrototype *prototype = TIntermFunctionPrototype *prototype = new TIntermFunctionPrototype(&function);
new TIntermFunctionPrototype(function.getReturnType(), TSymbolUniqueId(function));
// TODO(oetuaho@nvidia.com): Instead of converting the function information here, the node could
// point to the data that already exists in the symbol table.
prototype->getFunctionSymbolInfo()->setFromFunction(function);
prototype->setLine(location); prototype->setLine(location);
for (size_t i = 0; i < function.getParamCount(); i++) for (size_t i = 0; i < function.getParamCount(); i++)
...@@ -3250,7 +3246,7 @@ TIntermFunctionDefinition *TParseContext::addFunctionDefinition( ...@@ -3250,7 +3246,7 @@ TIntermFunctionDefinition *TParseContext::addFunctionDefinition(
if (mCurrentFunctionType->getBasicType() != EbtVoid && !mFunctionReturnsValue) if (mCurrentFunctionType->getBasicType() != EbtVoid && !mFunctionReturnsValue)
{ {
error(location, "function does not return a value:", error(location, "function does not return a value:",
functionPrototype->getFunctionSymbolInfo()->getName().c_str()); functionPrototype->getFunction()->name()->c_str());
} }
if (functionBody == nullptr) if (functionBody == nullptr)
......
...@@ -77,9 +77,7 @@ void WrapMainAndAppend(TIntermBlock *root, ...@@ -77,9 +77,7 @@ void WrapMainAndAppend(TIntermBlock *root,
// void main() // void main()
TFunction *newMain = new TFunction(symbolTable, NewPoolTString("main"), new TType(EbtVoid), TFunction *newMain = new TFunction(symbolTable, NewPoolTString("main"), new TType(EbtVoid),
SymbolType::UserDefined, false); SymbolType::UserDefined, false);
TIntermFunctionPrototype *newMainProto = new TIntermFunctionPrototype( TIntermFunctionPrototype *newMainProto = new TIntermFunctionPrototype(newMain);
TType(EbtVoid), main->getFunctionPrototype()->getFunctionSymbolInfo()->getId());
newMainProto->getFunctionSymbolInfo()->setFromFunction(*newMain);
// { // {
// main0(); // main0();
......
...@@ -730,20 +730,20 @@ TString DecorateVariableIfNeeded(const TName &name) ...@@ -730,20 +730,20 @@ TString DecorateVariableIfNeeded(const TName &name)
} }
} }
TString DecorateFunctionIfNeeded(const TName &name) TString DecorateFunctionIfNeeded(const TFunction *func)
{ {
if (name.isInternal()) if (func->symbolType() == SymbolType::AngleInternal)
{ {
// The name should not have a prefix reserved for user-defined variables or functions. // The name should not have a prefix reserved for user-defined variables or functions.
ASSERT(name.getString().compare(0, 2, "f_") != 0); ASSERT(func->name()->compare(0, 2, "f_") != 0);
ASSERT(name.getString().compare(0, 1, "_") != 0); ASSERT(func->name()->compare(0, 1, "_") != 0);
return name.getString(); return *func->name();
} }
ASSERT(name.getString().compare(0, 3, "gl_") != 0); ASSERT(func->name()->compare(0, 3, "gl_") != 0);
// Add an additional f prefix to functions so that they're always disambiguated from variables. // Add an additional f prefix to functions so that they're always disambiguated from variables.
// This is necessary in the corner case where a variable declaration hides a function that it // This is necessary in the corner case where a variable declaration hides a function that it
// uses in its initializer. // uses in its initializer.
return "f_" + name.getString(); return "f_" + (*func->name());
} }
TString TypeString(const TType &type) TString TypeString(const TType &type)
......
...@@ -111,7 +111,7 @@ TString SamplerString(HLSLTextureGroup type); ...@@ -111,7 +111,7 @@ TString SamplerString(HLSLTextureGroup type);
// Adds a prefix to user-defined names to avoid naming clashes. // Adds a prefix to user-defined names to avoid naming clashes.
TString Decorate(const TString &string); TString Decorate(const TString &string);
TString DecorateVariableIfNeeded(const TName &name); TString DecorateVariableIfNeeded(const TName &name);
TString DecorateFunctionIfNeeded(const TName &name); TString DecorateFunctionIfNeeded(const TFunction *func);
TString DecorateField(const TString &string, const TStructure &structure); TString DecorateField(const TString &string, const TStructure &structure);
TString DecoratePrivate(const TString &privateText); TString DecoratePrivate(const TString &privateText);
TString TypeString(const TType &type); TString TypeString(const TType &type);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment