Commit fe48632f by Olli Etuaho Committed by Commit Bot

Prefer identifying functions by using symbol ids

The shader translator code is now structured in a way that ensures that all function definition, function prototype and function call nodes store the integer symbol id for the function. This is guaranteed regardless of whether the function node is added while parsing or as a result of an AST transformation. TIntermAggregate nodes, which include function calls and constructors can now only be created by calling one of the TIntermAggregate::Create*() functions to ensure they have all the necessary properties. This makes it possible to keep track of functions using integer ids instead of their mangled name strings when generating the call graph and when using TLValueTrackingTraverser. This commit includes a few other small cleanups to the CallDAG class as well. BUG=angleproject:1490 TEST=angle_unittests, angle_end2end_tests Change-Id: Idd1013506cbe4c3380e20d90524a9cd09b890259 Reviewed-on: https://chromium-review.googlesource.com/459603Reviewed-by: 's avatarCorentin Wallez <cwallez@chromium.org> Reviewed-by: 's avatarGeoff Lang <geofflang@chromium.org> Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
parent 76e6565e
...@@ -50,9 +50,9 @@ TIntermAggregate *CreateReplacementCall(TIntermAggregate *originalCall, ...@@ -50,9 +50,9 @@ TIntermAggregate *CreateReplacementCall(TIntermAggregate *originalCall,
replacementArguments->push_back(arg); replacementArguments->push_back(arg);
} }
replacementArguments->push_back(returnValueTarget); replacementArguments->push_back(returnValueTarget);
TIntermAggregate *replacementCall = TIntermAggregate *replacementCall = TIntermAggregate::CreateFunctionCall(
new TIntermAggregate(TType(EbtVoid), EOpCallFunctionInAST, replacementArguments); TType(EbtVoid), originalCall->getFunctionSymbolInfo()->getId(),
*replacementCall->getFunctionSymbolInfo() = *originalCall->getFunctionSymbolInfo(); originalCall->getFunctionSymbolInfo()->getNameObj(), replacementArguments);
replacementCall->setLine(originalCall->getLine()); replacementCall->setLine(originalCall->getLine());
return replacementCall; return replacementCall;
} }
...@@ -110,7 +110,8 @@ bool ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(Visit visit ...@@ -110,7 +110,8 @@ bool ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(Visit visit
{ {
// Replace the whole prototype node with another node that has the out parameter // Replace the whole prototype node with another node that has the out parameter
// added. Also set the function to return void. // added. Also set the function to return void.
TIntermFunctionPrototype *replacement = new TIntermFunctionPrototype(TType(EbtVoid)); TIntermFunctionPrototype *replacement =
new TIntermFunctionPrototype(TType(EbtVoid), node->getFunctionSymbolInfo()->getId());
CopyAggregateChildren(node, replacement); CopyAggregateChildren(node, replacement);
replacement->getSequence()->push_back(CreateReturnValueOutSymbol(node->getType())); replacement->getSequence()->push_back(CreateReturnValueOutSymbol(node->getType()));
*replacement->getFunctionSymbolInfo() = *node->getFunctionSymbolInfo(); *replacement->getFunctionSymbolInfo() = *node->getFunctionSymbolInfo();
......
...@@ -9,7 +9,9 @@ ...@@ -9,7 +9,9 @@
// order. // order.
#include "compiler/translator/CallDAG.h" #include "compiler/translator/CallDAG.h"
#include "compiler/translator/Diagnostics.h" #include "compiler/translator/Diagnostics.h"
#include "compiler/translator/SymbolTable.h"
namespace sh namespace sh
{ {
...@@ -78,7 +80,7 @@ class CallDAG::CallDAGCreator : public TIntermTraverser ...@@ -78,7 +80,7 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
record.callees.push_back(static_cast<int>(callee->index)); record.callees.push_back(static_cast<int>(callee->index));
} }
(*idToIndex)[data.node->getFunctionSymbolInfo()->getId()] = (*idToIndex)[data.node->getFunctionSymbolInfo()->getId().get()] =
static_cast<int>(data.index); static_cast<int>(data.index);
} }
} }
...@@ -101,19 +103,20 @@ class CallDAG::CallDAGCreator : public TIntermTraverser ...@@ -101,19 +103,20 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
// Create the record if need be and remember the node. // Create the record if need be and remember the node.
if (visit == PreVisit) if (visit == PreVisit)
{ {
auto it = mFunctions.find(node->getFunctionSymbolInfo()->getName()); auto it = mFunctions.find(node->getFunctionSymbolInfo()->getId().get());
if (it == mFunctions.end()) if (it == mFunctions.end())
{ {
mCurrentFunction = &mFunctions[node->getFunctionSymbolInfo()->getName()]; mCurrentFunction = &mFunctions[node->getFunctionSymbolInfo()->getId().get()];
mCurrentFunction->name = node->getFunctionSymbolInfo()->getName();
} }
else else
{ {
mCurrentFunction = &it->second; mCurrentFunction = &it->second;
ASSERT(mCurrentFunction->name == node->getFunctionSymbolInfo()->getName());
} }
mCurrentFunction->node = node; mCurrentFunction->node = node;
mCurrentFunction->name = node->getFunctionSymbolInfo()->getName();
} }
else if (visit == PostVisit) else if (visit == PostVisit)
{ {
...@@ -125,8 +128,13 @@ class CallDAG::CallDAGCreator : public TIntermTraverser ...@@ -125,8 +128,13 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override
{ {
ASSERT(visit == PreVisit); ASSERT(visit == PreVisit);
if (mCurrentFunction != nullptr)
{
return false;
}
// Function declaration, create an empty record. // Function declaration, create an empty record.
auto &record = mFunctions[node->getFunctionSymbolInfo()->getName()]; auto &record = mFunctions[node->getFunctionSymbolInfo()->getId().get()];
record.name = node->getFunctionSymbolInfo()->getName(); record.name = node->getFunctionSymbolInfo()->getName();
// No need to traverse the parameters. // No need to traverse the parameters.
...@@ -139,10 +147,12 @@ class CallDAG::CallDAGCreator : public TIntermTraverser ...@@ -139,10 +147,12 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
if (visit == PreVisit && node->getOp() == EOpCallFunctionInAST) if (visit == PreVisit && node->getOp() == EOpCallFunctionInAST)
{ {
// Function call, add the callees // Function call, add the callees
auto it = mFunctions.find(node->getFunctionSymbolInfo()->getName()); auto it = mFunctions.find(node->getFunctionSymbolInfo()->getId().get());
ASSERT(it != mFunctions.end()); ASSERT(it != mFunctions.end());
// We might be in a top-level function call to set a global variable // We might be traversing the initializer of a global variable. Even though function
// calls in global scope are forbidden by the parser, some subsequent AST
// transformations can add them to emulate particular features.
if (mCurrentFunction) if (mCurrentFunction)
{ {
mCurrentFunction->callees.insert(&it->second); mCurrentFunction->callees.insert(&it->second);
...@@ -259,7 +269,7 @@ class CallDAG::CallDAGCreator : public TIntermTraverser ...@@ -259,7 +269,7 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
TDiagnostics *mDiagnostics; TDiagnostics *mDiagnostics;
std::map<TString, CreatorFunctionData> mFunctions; std::map<int, CreatorFunctionData> mFunctions;
CreatorFunctionData *mCurrentFunction; CreatorFunctionData *mCurrentFunction;
size_t mCurrentIndex; size_t mCurrentIndex;
}; };
...@@ -278,7 +288,7 @@ const size_t CallDAG::InvalidIndex = std::numeric_limits<size_t>::max(); ...@@ -278,7 +288,7 @@ const size_t CallDAG::InvalidIndex = std::numeric_limits<size_t>::max();
size_t CallDAG::findIndex(const TFunctionSymbolInfo *functionInfo) const size_t CallDAG::findIndex(const TFunctionSymbolInfo *functionInfo) const
{ {
auto it = mFunctionIdToIndex.find(functionInfo->getId()); auto it = mFunctionIdToIndex.find(functionInfo->getId().get());
if (it == mFunctionIdToIndex.end()) if (it == mFunctionIdToIndex.end())
{ {
......
...@@ -21,43 +21,6 @@ namespace sh ...@@ -21,43 +21,6 @@ namespace sh
namespace namespace
{ {
void SetInternalFunctionName(TFunctionSymbolInfo *functionInfo, const char *name)
{
TString nameStr(name);
nameStr = TFunction::mangleName(nameStr);
TName nameObj(nameStr);
nameObj.setInternal(true);
functionInfo->setNameObj(nameObj);
}
TIntermFunctionPrototype *CreateFunctionPrototypeNode(const char *name, const int functionId)
{
TType returnType(EbtVoid);
TIntermFunctionPrototype *functionNode = new TIntermFunctionPrototype(returnType);
SetInternalFunctionName(functionNode->getFunctionSymbolInfo(), name);
functionNode->getFunctionSymbolInfo()->setId(functionId);
return functionNode;
}
TIntermFunctionDefinition *CreateFunctionDefinitionNode(const char *name,
TIntermBlock *functionBody,
const int functionId)
{
TIntermFunctionPrototype *prototypeNode = CreateFunctionPrototypeNode(name, functionId);
return new TIntermFunctionDefinition(prototypeNode, functionBody);
}
TIntermAggregate *CreateFunctionCallNode(const char *name, const int functionId)
{
TType returnType(EbtVoid);
TIntermAggregate *functionNode =
new TIntermAggregate(returnType, EOpCallFunctionInAST, nullptr);
SetInternalFunctionName(functionNode->getFunctionSymbolInfo(), name);
functionNode->getFunctionSymbolInfo()->setId(functionId);
return functionNode;
}
class DeferGlobalInitializersTraverser : public TIntermTraverser class DeferGlobalInitializersTraverser : public TIntermTraverser
{ {
public: public:
...@@ -132,13 +95,13 @@ void DeferGlobalInitializersTraverser::insertInitFunction(TIntermBlock *root) ...@@ -132,13 +95,13 @@ void DeferGlobalInitializersTraverser::insertInitFunction(TIntermBlock *root)
{ {
return; return;
} }
const int initFunctionId = TSymbolTable::nextUniqueId(); TSymbolUniqueId initFunctionId;
const char *functionName = "initializeDeferredGlobals"; const char *functionName = "initializeDeferredGlobals";
// Add function prototype to the beginning of the shader // Add function prototype to the beginning of the shader
TIntermFunctionPrototype *functionPrototypeNode = TIntermFunctionPrototype *functionPrototypeNode =
CreateFunctionPrototypeNode(functionName, initFunctionId); CreateInternalFunctionPrototypeNode(TType(EbtVoid), functionName, initFunctionId);
root->getSequence()->insert(root->getSequence()->begin(), functionPrototypeNode); root->getSequence()->insert(root->getSequence()->begin(), functionPrototypeNode);
// Add function definition to the end of the shader // Add function definition to the end of the shader
...@@ -148,8 +111,8 @@ void DeferGlobalInitializersTraverser::insertInitFunction(TIntermBlock *root) ...@@ -148,8 +111,8 @@ void DeferGlobalInitializersTraverser::insertInitFunction(TIntermBlock *root)
{ {
functionBody->push_back(deferredInit); functionBody->push_back(deferredInit);
} }
TIntermFunctionDefinition *functionDefinition = TIntermFunctionDefinition *functionDefinition = CreateInternalFunctionDefinitionNode(
CreateFunctionDefinitionNode(functionName, functionBodyNode, initFunctionId); TType(EbtVoid), functionName, functionBodyNode, initFunctionId);
root->getSequence()->push_back(functionDefinition); root->getSequence()->push_back(functionDefinition);
// Insert call into main function // Insert call into main function
...@@ -158,8 +121,8 @@ void DeferGlobalInitializersTraverser::insertInitFunction(TIntermBlock *root) ...@@ -158,8 +121,8 @@ void DeferGlobalInitializersTraverser::insertInitFunction(TIntermBlock *root)
TIntermFunctionDefinition *nodeFunction = node->getAsFunctionDefinition(); TIntermFunctionDefinition *nodeFunction = node->getAsFunctionDefinition();
if (nodeFunction != nullptr && nodeFunction->getFunctionSymbolInfo()->isMain()) if (nodeFunction != nullptr && nodeFunction->getFunctionSymbolInfo()->isMain())
{ {
TIntermAggregate *functionCallNode = TIntermAggregate *functionCallNode = CreateInternalFunctionCallNode(
CreateFunctionCallNode(functionName, initFunctionId); TType(EbtVoid), functionName, initFunctionId, nullptr);
TIntermBlock *mainBody = nodeFunction->getBody(); TIntermBlock *mainBody = nodeFunction->getBody();
ASSERT(mainBody != nullptr); ASSERT(mainBody != nullptr);
......
...@@ -433,7 +433,8 @@ TIntermAggregate *createInternalFunctionCallNode(const TType &type, ...@@ -433,7 +433,8 @@ TIntermAggregate *createInternalFunctionCallNode(const TType &type,
{ {
TName nameObj(TFunction::GetMangledNameFromCall(name, *arguments)); TName nameObj(TFunction::GetMangledNameFromCall(name, *arguments));
nameObj.setInternal(true); nameObj.setInternal(true);
TIntermAggregate *callNode = new TIntermAggregate(type, EOpCallInternalRawFunction, arguments); TIntermAggregate *callNode =
TIntermAggregate::Create(type, EOpCallInternalRawFunction, arguments);
callNode->getFunctionSymbolInfo()->setNameObj(nameObj); callNode->getFunctionSymbolInfo()->setNameObj(nameObj);
return callNode; return callNode;
} }
......
...@@ -272,6 +272,57 @@ bool TIntermAggregateBase::insertChildNodes(TIntermSequence::size_type position, ...@@ -272,6 +272,57 @@ bool TIntermAggregateBase::insertChildNodes(TIntermSequence::size_type position,
return true; return true;
} }
TIntermAggregate *TIntermAggregate::CreateFunctionCall(const TFunction &func,
TIntermSequence *arguments)
{
TIntermAggregate *callNode =
new TIntermAggregate(func.getReturnType(), EOpCallFunctionInAST, arguments);
callNode->getFunctionSymbolInfo()->setFromFunction(func);
return callNode;
}
TIntermAggregate *TIntermAggregate::CreateFunctionCall(const TType &type,
const TSymbolUniqueId &id,
const TName &name,
TIntermSequence *arguments)
{
TIntermAggregate *callNode = new TIntermAggregate(type, EOpCallFunctionInAST, arguments);
callNode->getFunctionSymbolInfo()->setId(id);
callNode->getFunctionSymbolInfo()->setNameObj(name);
return callNode;
}
TIntermAggregate *TIntermAggregate::CreateBuiltInFunctionCall(const TFunction &func,
TIntermSequence *arguments)
{
TIntermAggregate *callNode =
new TIntermAggregate(func.getReturnType(), EOpCallBuiltInFunction, arguments);
callNode->getFunctionSymbolInfo()->setFromFunction(func);
// Note that name needs to be set before texture function type is determined.
callNode->setBuiltInFunctionPrecision();
return callNode;
}
TIntermAggregate *TIntermAggregate::CreateConstructor(const TType &type,
TOperator op,
TIntermSequence *arguments)
{
TIntermAggregate *constructorNode = new TIntermAggregate(type, op, arguments);
ASSERT(constructorNode->isConstructor());
return constructorNode;
}
TIntermAggregate *TIntermAggregate::Create(const TType &type,
TOperator op,
TIntermSequence *arguments)
{
TIntermAggregate *node = new TIntermAggregate(type, op, arguments);
ASSERT(op != EOpCallFunctionInAST); // Should use CreateFunctionCall
ASSERT(op != EOpCallBuiltInFunction); // Should use CreateBuiltInFunctionCall
ASSERT(!node->isConstructor()); // Should use CreateConstructor
return node;
}
TIntermAggregate::TIntermAggregate(const TType &type, TOperator op, TIntermSequence *arguments) TIntermAggregate::TIntermAggregate(const TType &type, TOperator op, TIntermSequence *arguments)
: TIntermOperator(op), mUseEmulatedFunction(false), mGotPrecisionFromChildren(false) : TIntermOperator(op), mUseEmulatedFunction(false), mGotPrecisionFromChildren(false)
{ {
...@@ -557,7 +608,8 @@ TIntermTyped *TIntermTyped::CreateZero(const TType &type) ...@@ -557,7 +608,8 @@ TIntermTyped *TIntermTyped::CreateZero(const TType &type)
} }
} }
return new TIntermAggregate(constType, sh::TypeToConstructorOperator(type), arguments); return TIntermAggregate::CreateConstructor(constType, sh::TypeToConstructorOperator(type),
arguments);
} }
// static // static
...@@ -579,7 +631,45 @@ TIntermConstantUnion::TIntermConstantUnion(const TIntermConstantUnion &node) : T ...@@ -579,7 +631,45 @@ TIntermConstantUnion::TIntermConstantUnion(const TIntermConstantUnion &node) : T
void TFunctionSymbolInfo::setFromFunction(const TFunction &function) void TFunctionSymbolInfo::setFromFunction(const TFunction &function)
{ {
setName(function.getMangledName()); setName(function.getMangledName());
setId(function.getUniqueId()); setId(TSymbolUniqueId(function));
}
TFunctionSymbolInfo::TFunctionSymbolInfo(const TSymbolUniqueId &id) : mId(new TSymbolUniqueId(id))
{
}
TFunctionSymbolInfo::TFunctionSymbolInfo(const TFunctionSymbolInfo &info)
: mName(info.mName), mId(nullptr)
{
if (info.mId)
{
mId = new TSymbolUniqueId(*info.mId);
}
}
TFunctionSymbolInfo &TFunctionSymbolInfo::operator=(const TFunctionSymbolInfo &info)
{
mName = info.mName;
if (info.mId)
{
mId = new TSymbolUniqueId(*info.mId);
}
else
{
mId = nullptr;
}
return *this;
}
void TFunctionSymbolInfo::setId(const TSymbolUniqueId &id)
{
mId = new TSymbolUniqueId(id);
}
const TSymbolUniqueId &TFunctionSymbolInfo::getId() const
{
ASSERT(mId);
return *mId;
} }
TIntermAggregate::TIntermAggregate(const TIntermAggregate &node) TIntermAggregate::TIntermAggregate(const TIntermAggregate &node)
...@@ -597,6 +687,16 @@ TIntermAggregate::TIntermAggregate(const TIntermAggregate &node) ...@@ -597,6 +687,16 @@ TIntermAggregate::TIntermAggregate(const TIntermAggregate &node)
} }
} }
TIntermAggregate *TIntermAggregate::shallowCopy() const
{
TIntermSequence *copySeq = new TIntermSequence();
copySeq->insert(copySeq->begin(), getSequence()->begin(), getSequence()->end());
TIntermAggregate *copyNode = new TIntermAggregate(mType, mOp, copySeq);
*copyNode->getFunctionSymbolInfo() = mFunctionInfo;
copyNode->setLine(mLine);
return copyNode;
}
TIntermSwizzle::TIntermSwizzle(const TIntermSwizzle &node) : TIntermTyped(node) TIntermSwizzle::TIntermSwizzle(const TIntermSwizzle &node) : TIntermTyped(node)
{ {
TIntermTyped *operandCopy = node.mOperand->deepCopy(); TIntermTyped *operandCopy = node.mOperand->deepCopy();
...@@ -3276,4 +3376,45 @@ void TIntermTraverser::queueReplacementWithParent(TIntermNode *parent, ...@@ -3276,4 +3376,45 @@ void TIntermTraverser::queueReplacementWithParent(TIntermNode *parent,
mReplacements.push_back(NodeUpdateEntry(parent, original, replacement, originalBecomesChild)); mReplacements.push_back(NodeUpdateEntry(parent, original, replacement, originalBecomesChild));
} }
TName TIntermTraverser::GetInternalFunctionName(const char *name)
{
TString nameStr(name);
nameStr = TFunction::mangleName(nameStr);
TName nameObj(nameStr);
nameObj.setInternal(true);
return nameObj;
}
TIntermFunctionPrototype *TIntermTraverser::CreateInternalFunctionPrototypeNode(
const TType &returnType,
const char *name,
const TSymbolUniqueId &functionId)
{
TIntermFunctionPrototype *functionNode = new TIntermFunctionPrototype(returnType, functionId);
functionNode->getFunctionSymbolInfo()->setNameObj(GetInternalFunctionName(name));
return functionNode;
}
TIntermFunctionDefinition *TIntermTraverser::CreateInternalFunctionDefinitionNode(
const TType &returnType,
const char *name,
TIntermBlock *functionBody,
const TSymbolUniqueId &functionId)
{
TIntermFunctionPrototype *prototypeNode =
CreateInternalFunctionPrototypeNode(returnType, name, functionId);
return new TIntermFunctionDefinition(prototypeNode, functionBody);
}
TIntermAggregate *TIntermTraverser::CreateInternalFunctionCallNode(
const TType &returnType,
const char *name,
const TSymbolUniqueId &functionId,
TIntermSequence *arguments)
{
TIntermAggregate *functionNode = TIntermAggregate::CreateFunctionCall(
returnType, functionId, GetInternalFunctionName(name), arguments);
return functionNode;
}
} // namespace sh } // namespace sh
...@@ -56,6 +56,7 @@ class TIntermRaw; ...@@ -56,6 +56,7 @@ class TIntermRaw;
class TIntermBranch; class TIntermBranch;
class TSymbolTable; class TSymbolTable;
class TSymbolUniqueId;
class TFunction; class TFunction;
// Encapsulate an identifier string and track whether it is coming from the original shader code // Encapsulate an identifier string and track whether it is coming from the original shader code
...@@ -414,7 +415,7 @@ class TIntermOperator : public TIntermTyped ...@@ -414,7 +415,7 @@ class TIntermOperator : public TIntermTyped
TIntermOperator(const TIntermOperator &) = default; TIntermOperator(const TIntermOperator &) = default;
TOperator mOp; const TOperator mOp;
}; };
// Node for vector swizzles. // Node for vector swizzles.
...@@ -535,10 +536,11 @@ class TFunctionSymbolInfo ...@@ -535,10 +536,11 @@ class TFunctionSymbolInfo
{ {
public: public:
POOL_ALLOCATOR_NEW_DELETE(); POOL_ALLOCATOR_NEW_DELETE();
TFunctionSymbolInfo() : mId(0) {} TFunctionSymbolInfo(const TSymbolUniqueId &id);
TFunctionSymbolInfo() : mId(nullptr) {}
TFunctionSymbolInfo(const TFunctionSymbolInfo &) = default; TFunctionSymbolInfo(const TFunctionSymbolInfo &info);
TFunctionSymbolInfo &operator=(const TFunctionSymbolInfo &) = default; TFunctionSymbolInfo &operator=(const TFunctionSymbolInfo &info);
void setFromFunction(const TFunction &function); void setFromFunction(const TFunction &function);
...@@ -550,11 +552,12 @@ class TFunctionSymbolInfo ...@@ -550,11 +552,12 @@ class TFunctionSymbolInfo
void setName(const TString &name) { mName.setString(name); } void setName(const TString &name) { mName.setString(name); }
bool isMain() const { return mName.getString() == "main("; } bool isMain() const { return mName.getString() == "main("; }
void setId(int functionId) { mId = functionId; } void setId(const TSymbolUniqueId &functionId);
int getId() const { return mId; } const TSymbolUniqueId &getId() const;
private: private:
TName mName; TName mName;
int mId; TSymbolUniqueId *mId;
}; };
typedef TVector<TIntermNode *> TIntermSequence; typedef TVector<TIntermNode *> TIntermSequence;
...@@ -593,12 +596,28 @@ class TIntermAggregateBase ...@@ -593,12 +596,28 @@ class TIntermAggregateBase
class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase
{ {
public: public:
TIntermAggregate(const TType &type, TOperator op, TIntermSequence *arguments); static TIntermAggregate *CreateFunctionCall(const TFunction &func, TIntermSequence *arguments);
// If using this, ensure that there's a consistent function definition with the same symbol id
// added to the AST.
static TIntermAggregate *CreateFunctionCall(const TType &type,
const TSymbolUniqueId &id,
const TName &name,
TIntermSequence *arguments);
static TIntermAggregate *CreateBuiltInFunctionCall(const TFunction &func,
TIntermSequence *arguments);
static TIntermAggregate *CreateConstructor(const TType &type,
TOperator op,
TIntermSequence *arguments);
static TIntermAggregate *Create(const TType &type, TOperator op, TIntermSequence *arguments);
~TIntermAggregate() {} ~TIntermAggregate() {}
// Note: only supported for nodes that can be a part of an expression. // Note: only supported for nodes that can be a part of an expression.
TIntermTyped *deepCopy() const override { return new TIntermAggregate(*this); } TIntermTyped *deepCopy() const override { return new TIntermAggregate(*this); }
TIntermAggregate *shallowCopy() const;
TIntermAggregate *getAsAggregate() override { return this; } TIntermAggregate *getAsAggregate() override { return this; }
void traverse(TIntermTraverser *it) override; void traverse(TIntermTraverser *it) override;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override; bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
...@@ -619,10 +638,6 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase ...@@ -619,10 +638,6 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase
TFunctionSymbolInfo *getFunctionSymbolInfo() { return &mFunctionInfo; } TFunctionSymbolInfo *getFunctionSymbolInfo() { return &mFunctionInfo; }
const TFunctionSymbolInfo *getFunctionSymbolInfo() const { return &mFunctionInfo; } const TFunctionSymbolInfo *getFunctionSymbolInfo() const { return &mFunctionInfo; }
// Used for built-in functions under EOpCallBuiltInFunction. The function name in the symbol
// info needs to be set before calling this.
void setBuiltInFunctionPrecision();
protected: protected:
TIntermSequence mArguments; TIntermSequence mArguments;
...@@ -635,6 +650,8 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase ...@@ -635,6 +650,8 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase
TFunctionSymbolInfo mFunctionInfo; TFunctionSymbolInfo mFunctionInfo;
private: private:
TIntermAggregate(const TType &type, TOperator op, TIntermSequence *arguments);
TIntermAggregate(const TIntermAggregate &node); // note: not deleted, just private! TIntermAggregate(const TIntermAggregate &node); // note: not deleted, just private!
void setTypePrecisionAndQualifier(const TType &type); void setTypePrecisionAndQualifier(const TType &type);
...@@ -647,6 +664,10 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase ...@@ -647,6 +664,10 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase
// Returns true if precision was set according to special rules for this built-in. // Returns true if precision was set according to special rules for this built-in.
bool setPrecisionForSpecialBuiltInOp(); bool setPrecisionForSpecialBuiltInOp();
// Used for built-in functions under EOpCallBuiltInFunction. The function name in the symbol
// info needs to be set before calling this.
void setBuiltInFunctionPrecision();
}; };
// A list of statements. Either the root node which contains declarations and function definitions, // A list of statements. Either the root node which contains declarations and function definitions,
...@@ -678,7 +699,10 @@ class TIntermFunctionPrototype : public TIntermTyped, public TIntermAggregateBas ...@@ -678,7 +699,10 @@ class TIntermFunctionPrototype : public TIntermTyped, public TIntermAggregateBas
public: public:
// TODO(oetuaho@nvidia.com): See if TFunctionSymbolInfo could be added to constructor // TODO(oetuaho@nvidia.com): See if TFunctionSymbolInfo could be added to constructor
// parameters. // parameters.
TIntermFunctionPrototype(const TType &type) : TIntermTyped(type) {} TIntermFunctionPrototype(const TType &type, const TSymbolUniqueId &id)
: TIntermTyped(type), mFunctionInfo(id)
{
}
~TIntermFunctionPrototype() {} ~TIntermFunctionPrototype() {}
TIntermFunctionPrototype *getAsFunctionPrototypeNode() override { return this; } TIntermFunctionPrototype *getAsFunctionPrototypeNode() override { return this; }
...@@ -967,6 +991,20 @@ class TIntermTraverser : angle::NonCopyable ...@@ -967,6 +991,20 @@ class TIntermTraverser : angle::NonCopyable
// Start creating temporary symbols from the given temporary symbol index + 1. // Start creating temporary symbols from the given temporary symbol index + 1.
void useTemporaryIndex(unsigned int *temporaryIndex); void useTemporaryIndex(unsigned int *temporaryIndex);
static TIntermFunctionPrototype *CreateInternalFunctionPrototypeNode(
const TType &returnType,
const char *name,
const TSymbolUniqueId &functionId);
static TIntermFunctionDefinition *CreateInternalFunctionDefinitionNode(
const TType &returnType,
const char *name,
TIntermBlock *functionBody,
const TSymbolUniqueId &functionId);
static TIntermAggregate *CreateInternalFunctionCallNode(const TType &returnType,
const char *name,
const TSymbolUniqueId &functionId,
TIntermSequence *arguments);
protected: protected:
// Should only be called from traverse*() functions // Should only be called from traverse*() functions
void incrementDepth(TIntermNode *current) void incrementDepth(TIntermNode *current)
...@@ -1112,6 +1150,8 @@ class TIntermTraverser : angle::NonCopyable ...@@ -1112,6 +1150,8 @@ class TIntermTraverser : angle::NonCopyable
std::vector<NodeInsertMultipleEntry> mInsertions; std::vector<NodeInsertMultipleEntry> mInsertions;
private: private:
static TName GetInternalFunctionName(const char *name);
// To replace a single node with another on the parent node // To replace a single node with another on the parent node
struct NodeUpdateEntry struct NodeUpdateEntry
{ {
...@@ -1195,7 +1235,7 @@ class TLValueTrackingTraverser : public TIntermTraverser ...@@ -1195,7 +1235,7 @@ class TLValueTrackingTraverser : public TIntermTraverser
bool operatorRequiresLValue() const { return mOperatorRequiresLValue; } bool operatorRequiresLValue() const { return mOperatorRequiresLValue; }
// Add a function encountered during traversal to the function map. // Add a function encountered during traversal to the function map.
void addToFunctionMap(const TName &name, TIntermSequence *paramSequence); void addToFunctionMap(const TSymbolUniqueId &id, TIntermSequence *paramSequence);
// Return true if the prototype or definition of the function being called has been encountered // Return true if the prototype or definition of the function being called has been encountered
// during traversal. // during traversal.
...@@ -1211,20 +1251,8 @@ class TLValueTrackingTraverser : public TIntermTraverser ...@@ -1211,20 +1251,8 @@ class TLValueTrackingTraverser : public TIntermTraverser
bool mOperatorRequiresLValue; bool mOperatorRequiresLValue;
bool mInFunctionCallOutParameter; bool mInFunctionCallOutParameter;
struct TNameComparator // Map from function symbol id values to their parameter sequences
{ TMap<int, TIntermSequence *> mFunctionMap;
bool operator()(const TName &a, const TName &b) const
{
int compareResult = a.getString().compare(b.getString());
if (compareResult != 0)
return compareResult < 0;
// Internal functions may have same names as non-internal functions.
return !a.isInternal() && b.isInternal();
}
};
// Map from mangled function names to their parameter sequences
TMap<TName, TIntermSequence *, TNameComparator> mFunctionMap;
const TSymbolTable &mSymbolTable; const TSymbolTable &mSymbolTable;
const int mShaderVersion; const int mShaderVersion;
......
...@@ -228,22 +228,23 @@ void TIntermTraverser::nextTemporaryIndex() ...@@ -228,22 +228,23 @@ void TIntermTraverser::nextTemporaryIndex()
++(*mTemporaryIndex); ++(*mTemporaryIndex);
} }
void TLValueTrackingTraverser::addToFunctionMap(const TName &name, TIntermSequence *paramSequence) void TLValueTrackingTraverser::addToFunctionMap(const TSymbolUniqueId &id,
TIntermSequence *paramSequence)
{ {
mFunctionMap[name] = paramSequence; mFunctionMap[id.get()] = paramSequence;
} }
bool TLValueTrackingTraverser::isInFunctionMap(const TIntermAggregate *callNode) const bool TLValueTrackingTraverser::isInFunctionMap(const TIntermAggregate *callNode) const
{ {
ASSERT(callNode->getOp() == EOpCallFunctionInAST); ASSERT(callNode->getOp() == EOpCallFunctionInAST);
return (mFunctionMap.find(callNode->getFunctionSymbolInfo()->getNameObj()) != return (mFunctionMap.find(callNode->getFunctionSymbolInfo()->getId().get()) !=
mFunctionMap.end()); mFunctionMap.end());
} }
TIntermSequence *TLValueTrackingTraverser::getFunctionParameters(const TIntermAggregate *callNode) TIntermSequence *TLValueTrackingTraverser::getFunctionParameters(const TIntermAggregate *callNode)
{ {
ASSERT(isInFunctionMap(callNode)); ASSERT(isInFunctionMap(callNode));
return mFunctionMap[callNode->getFunctionSymbolInfo()->getNameObj()]; return mFunctionMap[callNode->getFunctionSymbolInfo()->getId().get()];
} }
void TLValueTrackingTraverser::setInFunctionCallOutParameter(bool inOutParameter) void TLValueTrackingTraverser::setInFunctionCallOutParameter(bool inOutParameter)
...@@ -623,7 +624,7 @@ void TIntermTraverser::traverseAggregate(TIntermAggregate *node) ...@@ -623,7 +624,7 @@ void TIntermTraverser::traverseAggregate(TIntermAggregate *node)
void TLValueTrackingTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node) void TLValueTrackingTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node)
{ {
TIntermSequence *sequence = node->getSequence(); TIntermSequence *sequence = node->getSequence();
addToFunctionMap(node->getFunctionSymbolInfo()->getNameObj(), sequence); addToFunctionMap(node->getFunctionSymbolInfo()->getId(), sequence);
TIntermTraverser::traverseFunctionPrototype(node); TIntermTraverser::traverseFunctionPrototype(node);
} }
......
...@@ -2490,7 +2490,8 @@ TIntermFunctionPrototype *TParseContext::createPrototypeNodeFromFunction( ...@@ -2490,7 +2490,8 @@ TIntermFunctionPrototype *TParseContext::createPrototypeNodeFromFunction(
const TSourceLoc &location, const TSourceLoc &location,
bool insertParametersToSymbolTable) bool insertParametersToSymbolTable)
{ {
TIntermFunctionPrototype *prototype = new TIntermFunctionPrototype(function.getReturnType()); TIntermFunctionPrototype *prototype =
new TIntermFunctionPrototype(function.getReturnType(), TSymbolUniqueId(function));
// TODO(oetuaho@nvidia.com): Instead of converting the function information here, the node could // TODO(oetuaho@nvidia.com): Instead of converting the function information here, the node could
// point to the data that already exists in the symbol table. // point to the data that already exists in the symbol table.
prototype->getFunctionSymbolInfo()->setFromFunction(function); prototype->getFunctionSymbolInfo()->setFromFunction(function);
...@@ -2803,9 +2804,8 @@ TIntermTyped *TParseContext::addConstructor(TIntermSequence *arguments, ...@@ -2803,9 +2804,8 @@ TIntermTyped *TParseContext::addConstructor(TIntermSequence *arguments,
return TIntermTyped::CreateZero(type); return TIntermTyped::CreateZero(type);
} }
TIntermAggregate *constructorNode = new TIntermAggregate(type, op, arguments); TIntermAggregate *constructorNode = TIntermAggregate::CreateConstructor(type, op, arguments);
constructorNode->setLine(line); constructorNode->setLine(line);
ASSERT(constructorNode->isConstructor());
TIntermTyped *constConstructor = TIntermTyped *constConstructor =
intermediate.foldAggregateBuiltIn(constructorNode, mDiagnostics); intermediate.foldAggregateBuiltIn(constructorNode, mDiagnostics);
...@@ -4533,7 +4533,7 @@ TIntermTyped *TParseContext::addNonConstructorFunctionCall(TFunction *fnCall, ...@@ -4533,7 +4533,7 @@ TIntermTyped *TParseContext::addNonConstructorFunctionCall(TFunction *fnCall,
else else
{ {
TIntermAggregate *callNode = TIntermAggregate *callNode =
new TIntermAggregate(fnCandidate->getReturnType(), op, arguments); TIntermAggregate::Create(fnCandidate->getReturnType(), op, arguments);
callNode->setLine(loc); callNode->setLine(loc);
// Some built-in functions have out parameters too. // Some built-in functions have out parameters too.
...@@ -4561,19 +4561,13 @@ TIntermTyped *TParseContext::addNonConstructorFunctionCall(TFunction *fnCall, ...@@ -4561,19 +4561,13 @@ TIntermTyped *TParseContext::addNonConstructorFunctionCall(TFunction *fnCall,
// This needs to happen after the function info including name is set. // This needs to happen after the function info including name is set.
if (builtIn) if (builtIn)
{ {
callNode = new TIntermAggregate(fnCandidate->getReturnType(), callNode = TIntermAggregate::CreateBuiltInFunctionCall(*fnCandidate, arguments);
EOpCallBuiltInFunction, arguments);
// Note that name needs to be set before texture function type is determined.
callNode->getFunctionSymbolInfo()->setFromFunction(*fnCandidate);
callNode->setBuiltInFunctionPrecision();
checkTextureOffsetConst(callNode); checkTextureOffsetConst(callNode);
checkImageMemoryAccessForBuiltinFunctions(callNode); checkImageMemoryAccessForBuiltinFunctions(callNode);
} }
else else
{ {
callNode = new TIntermAggregate(fnCandidate->getReturnType(), callNode = TIntermAggregate::CreateFunctionCall(*fnCandidate, arguments);
EOpCallFunctionInAST, arguments);
callNode->getFunctionSymbolInfo()->setFromFunction(*fnCandidate);
checkImageMemoryAccessForUserDefinedFunctions(fnCandidate, callNode); checkImageMemoryAccessForUserDefinedFunctions(fnCandidate, callNode);
} }
......
...@@ -89,8 +89,7 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -89,8 +89,7 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
16, node->getFunctionSymbolInfo()->getName().length() - 20); 16, node->getFunctionSymbolInfo()->getName().length() - 20);
TString newName = "texelFetch" + newArgs; TString newName = "texelFetch" + newArgs;
TSymbol *texelFetchSymbol = symbolTable->findBuiltIn(newName, shaderVersion); TSymbol *texelFetchSymbol = symbolTable->findBuiltIn(newName, shaderVersion);
ASSERT(texelFetchSymbol); ASSERT(texelFetchSymbol && texelFetchSymbol->isFunction());
int uniqueId = texelFetchSymbol->getUniqueId();
// Create new node that represents the call of function texelFetch. // Create new node that represents the call of function texelFetch.
// Its argument list will be: texelFetch(sampler, Position+offset, lod). // Its argument list will be: texelFetch(sampler, Position+offset, lod).
...@@ -117,8 +116,8 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -117,8 +116,8 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
TIntermTyped *zeroNode = TIntermTyped::CreateZero(TType(EbtInt)); TIntermTyped *zeroNode = TIntermTyped::CreateZero(TType(EbtInt));
constructOffsetIvecArguments->push_back(zeroNode); constructOffsetIvecArguments->push_back(zeroNode);
offsetNode = new TIntermAggregate(texCoordNode->getType(), EOpConstructIVec3, offsetNode = TIntermAggregate::CreateConstructor(texCoordNode->getType(), EOpConstructIVec3,
constructOffsetIvecArguments); constructOffsetIvecArguments);
offsetNode->setLine(texCoordNode->getLine()); offsetNode->setLine(texCoordNode->getLine());
} }
else else
...@@ -136,10 +135,8 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -136,10 +135,8 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
ASSERT(texelFetchArguments->size() == 3u); ASSERT(texelFetchArguments->size() == 3u);
TIntermAggregate *texelFetchNode = TIntermAggregate *texelFetchNode = TIntermAggregate::CreateBuiltInFunctionCall(
new TIntermAggregate(node->getType(), EOpCallBuiltInFunction, texelFetchArguments); *static_cast<const TFunction *>(texelFetchSymbol), texelFetchArguments);
texelFetchNode->getFunctionSymbolInfo()->setName(newName);
texelFetchNode->getFunctionSymbolInfo()->setId(uniqueId);
texelFetchNode->setLine(node->getLine()); texelFetchNode->setLine(node->getLine());
// Replace the old node by this new node. // Replace the old node by this new node.
......
...@@ -55,16 +55,6 @@ TIntermBinary *CopyAssignmentNode(TIntermBinary *node) ...@@ -55,16 +55,6 @@ TIntermBinary *CopyAssignmentNode(TIntermBinary *node)
return new TIntermBinary(node->getOp(), node->getLeft(), node->getRight()); return new TIntermBinary(node->getOp(), node->getLeft(), node->getRight());
} }
// Performs a shallow copy of a constructor/function call node.
TIntermAggregate *CopyAggregateNode(TIntermAggregate *node)
{
TIntermSequence *copySeq = new TIntermSequence();
copySeq->insert(copySeq->begin(), node->getSequence()->begin(), node->getSequence()->end());
TIntermAggregate *copyNode = new TIntermAggregate(node->getType(), node->getOp(), copySeq);
*copyNode->getFunctionSymbolInfo() = *node->getFunctionSymbolInfo();
return copyNode;
}
bool SeparateExpressionsTraverser::visitBinary(Visit visit, TIntermBinary *node) bool SeparateExpressionsTraverser::visitBinary(Visit visit, TIntermBinary *node)
{ {
if (mFoundArrayExpression) if (mFoundArrayExpression)
...@@ -104,7 +94,7 @@ bool SeparateExpressionsTraverser::visitAggregate(Visit visit, TIntermAggregate ...@@ -104,7 +94,7 @@ bool SeparateExpressionsTraverser::visitAggregate(Visit visit, TIntermAggregate
mFoundArrayExpression = true; mFoundArrayExpression = true;
TIntermSequence insertions; TIntermSequence insertions;
insertions.push_back(createTempInitDeclaration(CopyAggregateNode(node))); insertions.push_back(createTempInitDeclaration(node->shallowCopy()));
insertStatementsInParentBlock(insertions); insertStatementsInParentBlock(insertions);
queueReplacement(node, createTempSymbol(node->getType()), OriginalNode::IS_DROPPED); queueReplacement(node, createTempSymbol(node->getType()), OriginalNode::IS_DROPPED);
......
...@@ -26,6 +26,19 @@ namespace sh ...@@ -26,6 +26,19 @@ namespace sh
int TSymbolTable::uniqueIdCounter = 0; int TSymbolTable::uniqueIdCounter = 0;
TSymbolUniqueId::TSymbolUniqueId() : mId(TSymbolTable::nextUniqueId())
{
}
TSymbolUniqueId::TSymbolUniqueId(const TSymbol &symbol) : mId(symbol.getUniqueId())
{
}
int TSymbolUniqueId::get() const
{
return mId;
}
TSymbol::TSymbol(const TString *n) : uniqueId(TSymbolTable::nextUniqueId()), name(n) TSymbol::TSymbol(const TString *n) : uniqueId(TSymbolTable::nextUniqueId()), name(n)
{ {
} }
......
...@@ -41,6 +41,22 @@ ...@@ -41,6 +41,22 @@
namespace sh namespace sh
{ {
// Encapsulates a unique id for a symbol.
class TSymbolUniqueId
{
public:
POOL_ALLOCATOR_NEW_DELETE();
TSymbolUniqueId();
TSymbolUniqueId(const TSymbol &symbol);
TSymbolUniqueId(const TSymbolUniqueId &) = default;
TSymbolUniqueId &operator=(const TSymbolUniqueId &) = default;
int get() const;
private:
int mId;
};
// Symbol base class. (Can build functions or variables out of these...) // Symbol base class. (Can build functions or variables out of these...)
class TSymbol : angle::NonCopyable class TSymbol : angle::NonCopyable
{ {
......
...@@ -17,7 +17,7 @@ void OutputFunction(TInfoSinkBase &out, const char *str, TFunctionSymbolInfo *in ...@@ -17,7 +17,7 @@ void OutputFunction(TInfoSinkBase &out, const char *str, TFunctionSymbolInfo *in
{ {
const char *internal = info->getNameObj().isInternal() ? " (internal function)" : ""; const char *internal = info->getNameObj().isInternal() ? " (internal function)" : "";
out << str << internal << ": " << info->getNameObj().getString() << " (symbol id " out << str << internal << ": " << info->getNameObj().getString() << " (symbol id "
<< info->getId() << ")"; << info->getId().get() << ")";
} }
// //
......
...@@ -189,7 +189,7 @@ TEST_F(IntermNodeTest, DeepCopyAggregateNode) ...@@ -189,7 +189,7 @@ TEST_F(IntermNodeTest, DeepCopyAggregateNode)
originalSeq->push_back(createTestSymbol()); originalSeq->push_back(createTestSymbol());
originalSeq->push_back(createTestSymbol()); originalSeq->push_back(createTestSymbol());
TIntermAggregate *original = TIntermAggregate *original =
new TIntermAggregate(originalSeq->at(0)->getAsTyped()->getType(), EOpMix, originalSeq); TIntermAggregate::Create(originalSeq->at(0)->getAsTyped()->getType(), EOpMix, originalSeq);
original->setLine(getTestSourceLoc()); original->setLine(getTestSourceLoc());
TIntermTyped *copyTyped = original->deepCopy(); TIntermTyped *copyTyped = original->deepCopy();
......
...@@ -69,7 +69,7 @@ TEST_F(PruneUnusedFunctionsTest, UnimplementedPrototype) ...@@ -69,7 +69,7 @@ TEST_F(PruneUnusedFunctionsTest, UnimplementedPrototype)
EXPECT_TRUE(foundInCode("main(", 1)); EXPECT_TRUE(foundInCode("main(", 1));
} }
// Check that used functions are not prunued (duh) // Check that used functions are not pruned (duh)
TEST_F(PruneUnusedFunctionsTest, UsedFunction) TEST_F(PruneUnusedFunctionsTest, UsedFunction)
{ {
const std::string &shaderString = const std::string &shaderString =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment