Commit fe48632f by Olli Etuaho Committed by Commit Bot

Prefer identifying functions by using symbol ids

The shader translator code is now structured in a way that ensures that all function definition, function prototype and function call nodes store the integer symbol id for the function. This is guaranteed regardless of whether the function node is added while parsing or as a result of an AST transformation. TIntermAggregate nodes, which include function calls and constructors can now only be created by calling one of the TIntermAggregate::Create*() functions to ensure they have all the necessary properties. This makes it possible to keep track of functions using integer ids instead of their mangled name strings when generating the call graph and when using TLValueTrackingTraverser. This commit includes a few other small cleanups to the CallDAG class as well. BUG=angleproject:1490 TEST=angle_unittests, angle_end2end_tests Change-Id: Idd1013506cbe4c3380e20d90524a9cd09b890259 Reviewed-on: https://chromium-review.googlesource.com/459603Reviewed-by: 's avatarCorentin Wallez <cwallez@chromium.org> Reviewed-by: 's avatarGeoff Lang <geofflang@chromium.org> Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
parent 76e6565e
...@@ -50,9 +50,9 @@ TIntermAggregate *CreateReplacementCall(TIntermAggregate *originalCall, ...@@ -50,9 +50,9 @@ TIntermAggregate *CreateReplacementCall(TIntermAggregate *originalCall,
replacementArguments->push_back(arg); replacementArguments->push_back(arg);
} }
replacementArguments->push_back(returnValueTarget); replacementArguments->push_back(returnValueTarget);
TIntermAggregate *replacementCall = TIntermAggregate *replacementCall = TIntermAggregate::CreateFunctionCall(
new TIntermAggregate(TType(EbtVoid), EOpCallFunctionInAST, replacementArguments); TType(EbtVoid), originalCall->getFunctionSymbolInfo()->getId(),
*replacementCall->getFunctionSymbolInfo() = *originalCall->getFunctionSymbolInfo(); originalCall->getFunctionSymbolInfo()->getNameObj(), replacementArguments);
replacementCall->setLine(originalCall->getLine()); replacementCall->setLine(originalCall->getLine());
return replacementCall; return replacementCall;
} }
...@@ -110,7 +110,8 @@ bool ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(Visit visit ...@@ -110,7 +110,8 @@ bool ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(Visit visit
{ {
// Replace the whole prototype node with another node that has the out parameter // Replace the whole prototype node with another node that has the out parameter
// added. Also set the function to return void. // added. Also set the function to return void.
TIntermFunctionPrototype *replacement = new TIntermFunctionPrototype(TType(EbtVoid)); TIntermFunctionPrototype *replacement =
new TIntermFunctionPrototype(TType(EbtVoid), node->getFunctionSymbolInfo()->getId());
CopyAggregateChildren(node, replacement); CopyAggregateChildren(node, replacement);
replacement->getSequence()->push_back(CreateReturnValueOutSymbol(node->getType())); replacement->getSequence()->push_back(CreateReturnValueOutSymbol(node->getType()));
*replacement->getFunctionSymbolInfo() = *node->getFunctionSymbolInfo(); *replacement->getFunctionSymbolInfo() = *node->getFunctionSymbolInfo();
......
...@@ -9,7 +9,9 @@ ...@@ -9,7 +9,9 @@
// order. // order.
#include "compiler/translator/CallDAG.h" #include "compiler/translator/CallDAG.h"
#include "compiler/translator/Diagnostics.h" #include "compiler/translator/Diagnostics.h"
#include "compiler/translator/SymbolTable.h"
namespace sh namespace sh
{ {
...@@ -78,7 +80,7 @@ class CallDAG::CallDAGCreator : public TIntermTraverser ...@@ -78,7 +80,7 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
record.callees.push_back(static_cast<int>(callee->index)); record.callees.push_back(static_cast<int>(callee->index));
} }
(*idToIndex)[data.node->getFunctionSymbolInfo()->getId()] = (*idToIndex)[data.node->getFunctionSymbolInfo()->getId().get()] =
static_cast<int>(data.index); static_cast<int>(data.index);
} }
} }
...@@ -101,19 +103,20 @@ class CallDAG::CallDAGCreator : public TIntermTraverser ...@@ -101,19 +103,20 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
// Create the record if need be and remember the node. // Create the record if need be and remember the node.
if (visit == PreVisit) if (visit == PreVisit)
{ {
auto it = mFunctions.find(node->getFunctionSymbolInfo()->getName()); auto it = mFunctions.find(node->getFunctionSymbolInfo()->getId().get());
if (it == mFunctions.end()) if (it == mFunctions.end())
{ {
mCurrentFunction = &mFunctions[node->getFunctionSymbolInfo()->getName()]; mCurrentFunction = &mFunctions[node->getFunctionSymbolInfo()->getId().get()];
mCurrentFunction->name = node->getFunctionSymbolInfo()->getName();
} }
else else
{ {
mCurrentFunction = &it->second; mCurrentFunction = &it->second;
ASSERT(mCurrentFunction->name == node->getFunctionSymbolInfo()->getName());
} }
mCurrentFunction->node = node; mCurrentFunction->node = node;
mCurrentFunction->name = node->getFunctionSymbolInfo()->getName();
} }
else if (visit == PostVisit) else if (visit == PostVisit)
{ {
...@@ -125,8 +128,13 @@ class CallDAG::CallDAGCreator : public TIntermTraverser ...@@ -125,8 +128,13 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override
{ {
ASSERT(visit == PreVisit); ASSERT(visit == PreVisit);
if (mCurrentFunction != nullptr)
{
return false;
}
// Function declaration, create an empty record. // Function declaration, create an empty record.
auto &record = mFunctions[node->getFunctionSymbolInfo()->getName()]; auto &record = mFunctions[node->getFunctionSymbolInfo()->getId().get()];
record.name = node->getFunctionSymbolInfo()->getName(); record.name = node->getFunctionSymbolInfo()->getName();
// No need to traverse the parameters. // No need to traverse the parameters.
...@@ -139,10 +147,12 @@ class CallDAG::CallDAGCreator : public TIntermTraverser ...@@ -139,10 +147,12 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
if (visit == PreVisit && node->getOp() == EOpCallFunctionInAST) if (visit == PreVisit && node->getOp() == EOpCallFunctionInAST)
{ {
// Function call, add the callees // Function call, add the callees
auto it = mFunctions.find(node->getFunctionSymbolInfo()->getName()); auto it = mFunctions.find(node->getFunctionSymbolInfo()->getId().get());
ASSERT(it != mFunctions.end()); ASSERT(it != mFunctions.end());
// We might be in a top-level function call to set a global variable // We might be traversing the initializer of a global variable. Even though function
// calls in global scope are forbidden by the parser, some subsequent AST
// transformations can add them to emulate particular features.
if (mCurrentFunction) if (mCurrentFunction)
{ {
mCurrentFunction->callees.insert(&it->second); mCurrentFunction->callees.insert(&it->second);
...@@ -259,7 +269,7 @@ class CallDAG::CallDAGCreator : public TIntermTraverser ...@@ -259,7 +269,7 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
TDiagnostics *mDiagnostics; TDiagnostics *mDiagnostics;
std::map<TString, CreatorFunctionData> mFunctions; std::map<int, CreatorFunctionData> mFunctions;
CreatorFunctionData *mCurrentFunction; CreatorFunctionData *mCurrentFunction;
size_t mCurrentIndex; size_t mCurrentIndex;
}; };
...@@ -278,7 +288,7 @@ const size_t CallDAG::InvalidIndex = std::numeric_limits<size_t>::max(); ...@@ -278,7 +288,7 @@ const size_t CallDAG::InvalidIndex = std::numeric_limits<size_t>::max();
size_t CallDAG::findIndex(const TFunctionSymbolInfo *functionInfo) const size_t CallDAG::findIndex(const TFunctionSymbolInfo *functionInfo) const
{ {
auto it = mFunctionIdToIndex.find(functionInfo->getId()); auto it = mFunctionIdToIndex.find(functionInfo->getId().get());
if (it == mFunctionIdToIndex.end()) if (it == mFunctionIdToIndex.end())
{ {
......
...@@ -21,43 +21,6 @@ namespace sh ...@@ -21,43 +21,6 @@ namespace sh
namespace namespace
{ {
void SetInternalFunctionName(TFunctionSymbolInfo *functionInfo, const char *name)
{
TString nameStr(name);
nameStr = TFunction::mangleName(nameStr);
TName nameObj(nameStr);
nameObj.setInternal(true);
functionInfo->setNameObj(nameObj);
}
TIntermFunctionPrototype *CreateFunctionPrototypeNode(const char *name, const int functionId)
{
TType returnType(EbtVoid);
TIntermFunctionPrototype *functionNode = new TIntermFunctionPrototype(returnType);
SetInternalFunctionName(functionNode->getFunctionSymbolInfo(), name);
functionNode->getFunctionSymbolInfo()->setId(functionId);
return functionNode;
}
TIntermFunctionDefinition *CreateFunctionDefinitionNode(const char *name,
TIntermBlock *functionBody,
const int functionId)
{
TIntermFunctionPrototype *prototypeNode = CreateFunctionPrototypeNode(name, functionId);
return new TIntermFunctionDefinition(prototypeNode, functionBody);
}
TIntermAggregate *CreateFunctionCallNode(const char *name, const int functionId)
{
TType returnType(EbtVoid);
TIntermAggregate *functionNode =
new TIntermAggregate(returnType, EOpCallFunctionInAST, nullptr);
SetInternalFunctionName(functionNode->getFunctionSymbolInfo(), name);
functionNode->getFunctionSymbolInfo()->setId(functionId);
return functionNode;
}
class DeferGlobalInitializersTraverser : public TIntermTraverser class DeferGlobalInitializersTraverser : public TIntermTraverser
{ {
public: public:
...@@ -132,13 +95,13 @@ void DeferGlobalInitializersTraverser::insertInitFunction(TIntermBlock *root) ...@@ -132,13 +95,13 @@ void DeferGlobalInitializersTraverser::insertInitFunction(TIntermBlock *root)
{ {
return; return;
} }
const int initFunctionId = TSymbolTable::nextUniqueId(); TSymbolUniqueId initFunctionId;
const char *functionName = "initializeDeferredGlobals"; const char *functionName = "initializeDeferredGlobals";
// Add function prototype to the beginning of the shader // Add function prototype to the beginning of the shader
TIntermFunctionPrototype *functionPrototypeNode = TIntermFunctionPrototype *functionPrototypeNode =
CreateFunctionPrototypeNode(functionName, initFunctionId); CreateInternalFunctionPrototypeNode(TType(EbtVoid), functionName, initFunctionId);
root->getSequence()->insert(root->getSequence()->begin(), functionPrototypeNode); root->getSequence()->insert(root->getSequence()->begin(), functionPrototypeNode);
// Add function definition to the end of the shader // Add function definition to the end of the shader
...@@ -148,8 +111,8 @@ void DeferGlobalInitializersTraverser::insertInitFunction(TIntermBlock *root) ...@@ -148,8 +111,8 @@ void DeferGlobalInitializersTraverser::insertInitFunction(TIntermBlock *root)
{ {
functionBody->push_back(deferredInit); functionBody->push_back(deferredInit);
} }
TIntermFunctionDefinition *functionDefinition = TIntermFunctionDefinition *functionDefinition = CreateInternalFunctionDefinitionNode(
CreateFunctionDefinitionNode(functionName, functionBodyNode, initFunctionId); TType(EbtVoid), functionName, functionBodyNode, initFunctionId);
root->getSequence()->push_back(functionDefinition); root->getSequence()->push_back(functionDefinition);
// Insert call into main function // Insert call into main function
...@@ -158,8 +121,8 @@ void DeferGlobalInitializersTraverser::insertInitFunction(TIntermBlock *root) ...@@ -158,8 +121,8 @@ void DeferGlobalInitializersTraverser::insertInitFunction(TIntermBlock *root)
TIntermFunctionDefinition *nodeFunction = node->getAsFunctionDefinition(); TIntermFunctionDefinition *nodeFunction = node->getAsFunctionDefinition();
if (nodeFunction != nullptr && nodeFunction->getFunctionSymbolInfo()->isMain()) if (nodeFunction != nullptr && nodeFunction->getFunctionSymbolInfo()->isMain())
{ {
TIntermAggregate *functionCallNode = TIntermAggregate *functionCallNode = CreateInternalFunctionCallNode(
CreateFunctionCallNode(functionName, initFunctionId); TType(EbtVoid), functionName, initFunctionId, nullptr);
TIntermBlock *mainBody = nodeFunction->getBody(); TIntermBlock *mainBody = nodeFunction->getBody();
ASSERT(mainBody != nullptr); ASSERT(mainBody != nullptr);
......
...@@ -433,7 +433,8 @@ TIntermAggregate *createInternalFunctionCallNode(const TType &type, ...@@ -433,7 +433,8 @@ TIntermAggregate *createInternalFunctionCallNode(const TType &type,
{ {
TName nameObj(TFunction::GetMangledNameFromCall(name, *arguments)); TName nameObj(TFunction::GetMangledNameFromCall(name, *arguments));
nameObj.setInternal(true); nameObj.setInternal(true);
TIntermAggregate *callNode = new TIntermAggregate(type, EOpCallInternalRawFunction, arguments); TIntermAggregate *callNode =
TIntermAggregate::Create(type, EOpCallInternalRawFunction, arguments);
callNode->getFunctionSymbolInfo()->setNameObj(nameObj); callNode->getFunctionSymbolInfo()->setNameObj(nameObj);
return callNode; return callNode;
} }
......
...@@ -272,6 +272,57 @@ bool TIntermAggregateBase::insertChildNodes(TIntermSequence::size_type position, ...@@ -272,6 +272,57 @@ bool TIntermAggregateBase::insertChildNodes(TIntermSequence::size_type position,
return true; return true;
} }
TIntermAggregate *TIntermAggregate::CreateFunctionCall(const TFunction &func,
TIntermSequence *arguments)
{
TIntermAggregate *callNode =
new TIntermAggregate(func.getReturnType(), EOpCallFunctionInAST, arguments);
callNode->getFunctionSymbolInfo()->setFromFunction(func);
return callNode;
}
TIntermAggregate *TIntermAggregate::CreateFunctionCall(const TType &type,
const TSymbolUniqueId &id,
const TName &name,
TIntermSequence *arguments)
{
TIntermAggregate *callNode = new TIntermAggregate(type, EOpCallFunctionInAST, arguments);
callNode->getFunctionSymbolInfo()->setId(id);
callNode->getFunctionSymbolInfo()->setNameObj(name);
return callNode;
}
TIntermAggregate *TIntermAggregate::CreateBuiltInFunctionCall(const TFunction &func,
TIntermSequence *arguments)
{
TIntermAggregate *callNode =
new TIntermAggregate(func.getReturnType(), EOpCallBuiltInFunction, arguments);
callNode->getFunctionSymbolInfo()->setFromFunction(func);
// Note that name needs to be set before texture function type is determined.
callNode->setBuiltInFunctionPrecision();
return callNode;
}
TIntermAggregate *TIntermAggregate::CreateConstructor(const TType &type,
TOperator op,
TIntermSequence *arguments)
{
TIntermAggregate *constructorNode = new TIntermAggregate(type, op, arguments);
ASSERT(constructorNode->isConstructor());
return constructorNode;
}
TIntermAggregate *TIntermAggregate::Create(const TType &type,
TOperator op,
TIntermSequence *arguments)
{
TIntermAggregate *node = new TIntermAggregate(type, op, arguments);
ASSERT(op != EOpCallFunctionInAST); // Should use CreateFunctionCall
ASSERT(op != EOpCallBuiltInFunction); // Should use CreateBuiltInFunctionCall
ASSERT(!node->isConstructor()); // Should use CreateConstructor
return node;
}
TIntermAggregate::TIntermAggregate(const TType &type, TOperator op, TIntermSequence *arguments) TIntermAggregate::TIntermAggregate(const TType &type, TOperator op, TIntermSequence *arguments)
: TIntermOperator(op), mUseEmulatedFunction(false), mGotPrecisionFromChildren(false) : TIntermOperator(op), mUseEmulatedFunction(false), mGotPrecisionFromChildren(false)
{ {
...@@ -557,7 +608,8 @@ TIntermTyped *TIntermTyped::CreateZero(const TType &type) ...@@ -557,7 +608,8 @@ TIntermTyped *TIntermTyped::CreateZero(const TType &type)
} }
} }
return new TIntermAggregate(constType, sh::TypeToConstructorOperator(type), arguments); return TIntermAggregate::CreateConstructor(constType, sh::TypeToConstructorOperator(type),
arguments);
} }
// static // static
...@@ -579,7 +631,45 @@ TIntermConstantUnion::TIntermConstantUnion(const TIntermConstantUnion &node) : T ...@@ -579,7 +631,45 @@ TIntermConstantUnion::TIntermConstantUnion(const TIntermConstantUnion &node) : T
void TFunctionSymbolInfo::setFromFunction(const TFunction &function) void TFunctionSymbolInfo::setFromFunction(const TFunction &function)
{ {
setName(function.getMangledName()); setName(function.getMangledName());
setId(function.getUniqueId()); setId(TSymbolUniqueId(function));
}
TFunctionSymbolInfo::TFunctionSymbolInfo(const TSymbolUniqueId &id) : mId(new TSymbolUniqueId(id))
{
}
TFunctionSymbolInfo::TFunctionSymbolInfo(const TFunctionSymbolInfo &info)
: mName(info.mName), mId(nullptr)
{
if (info.mId)
{
mId = new TSymbolUniqueId(*info.mId);
}
}
TFunctionSymbolInfo &TFunctionSymbolInfo::operator=(const TFunctionSymbolInfo &info)
{
mName = info.mName;
if (info.mId)
{
mId = new TSymbolUniqueId(*info.mId);
}
else
{
mId = nullptr;
}
return *this;
}
void TFunctionSymbolInfo::setId(const TSymbolUniqueId &id)
{
mId = new TSymbolUniqueId(id);
}
const TSymbolUniqueId &TFunctionSymbolInfo::getId() const
{
ASSERT(mId);
return *mId;
} }
TIntermAggregate::TIntermAggregate(const TIntermAggregate &node) TIntermAggregate::TIntermAggregate(const TIntermAggregate &node)
...@@ -597,6 +687,16 @@ TIntermAggregate::TIntermAggregate(const TIntermAggregate &node) ...@@ -597,6 +687,16 @@ TIntermAggregate::TIntermAggregate(const TIntermAggregate &node)
} }
} }
TIntermAggregate *TIntermAggregate::shallowCopy() const
{
TIntermSequence *copySeq = new TIntermSequence();
copySeq->insert(copySeq->begin(), getSequence()->begin(), getSequence()->end());
TIntermAggregate *copyNode = new TIntermAggregate(mType, mOp, copySeq);
*copyNode->getFunctionSymbolInfo() = mFunctionInfo;
copyNode->setLine(mLine);
return copyNode;
}
TIntermSwizzle::TIntermSwizzle(const TIntermSwizzle &node) : TIntermTyped(node) TIntermSwizzle::TIntermSwizzle(const TIntermSwizzle &node) : TIntermTyped(node)
{ {
TIntermTyped *operandCopy = node.mOperand->deepCopy(); TIntermTyped *operandCopy = node.mOperand->deepCopy();
...@@ -3276,4 +3376,45 @@ void TIntermTraverser::queueReplacementWithParent(TIntermNode *parent, ...@@ -3276,4 +3376,45 @@ void TIntermTraverser::queueReplacementWithParent(TIntermNode *parent,
mReplacements.push_back(NodeUpdateEntry(parent, original, replacement, originalBecomesChild)); mReplacements.push_back(NodeUpdateEntry(parent, original, replacement, originalBecomesChild));
} }
TName TIntermTraverser::GetInternalFunctionName(const char *name)
{
TString nameStr(name);
nameStr = TFunction::mangleName(nameStr);
TName nameObj(nameStr);
nameObj.setInternal(true);
return nameObj;
}
TIntermFunctionPrototype *TIntermTraverser::CreateInternalFunctionPrototypeNode(
const TType &returnType,
const char *name,
const TSymbolUniqueId &functionId)
{
TIntermFunctionPrototype *functionNode = new TIntermFunctionPrototype(returnType, functionId);
functionNode->getFunctionSymbolInfo()->setNameObj(GetInternalFunctionName(name));
return functionNode;
}
TIntermFunctionDefinition *TIntermTraverser::CreateInternalFunctionDefinitionNode(
const TType &returnType,
const char *name,
TIntermBlock *functionBody,
const TSymbolUniqueId &functionId)
{
TIntermFunctionPrototype *prototypeNode =
CreateInternalFunctionPrototypeNode(returnType, name, functionId);
return new TIntermFunctionDefinition(prototypeNode, functionBody);
}
TIntermAggregate *TIntermTraverser::CreateInternalFunctionCallNode(
const TType &returnType,
const char *name,
const TSymbolUniqueId &functionId,
TIntermSequence *arguments)
{
TIntermAggregate *functionNode = TIntermAggregate::CreateFunctionCall(
returnType, functionId, GetInternalFunctionName(name), arguments);
return functionNode;
}
} // namespace sh } // namespace sh
...@@ -56,6 +56,7 @@ class TIntermRaw; ...@@ -56,6 +56,7 @@ class TIntermRaw;
class TIntermBranch; class TIntermBranch;
class TSymbolTable; class TSymbolTable;
class TSymbolUniqueId;
class TFunction; class TFunction;
// Encapsulate an identifier string and track whether it is coming from the original shader code // Encapsulate an identifier string and track whether it is coming from the original shader code
...@@ -414,7 +415,7 @@ class TIntermOperator : public TIntermTyped ...@@ -414,7 +415,7 @@ class TIntermOperator : public TIntermTyped
TIntermOperator(const TIntermOperator &) = default; TIntermOperator(const TIntermOperator &) = default;
TOperator mOp; const TOperator mOp;
}; };
// Node for vector swizzles. // Node for vector swizzles.
...@@ -535,10 +536,11 @@ class TFunctionSymbolInfo ...@@ -535,10 +536,11 @@ class TFunctionSymbolInfo
{ {
public: public:
POOL_ALLOCATOR_NEW_DELETE(); POOL_ALLOCATOR_NEW_DELETE();
TFunctionSymbolInfo() : mId(0) {} TFunctionSymbolInfo(const TSymbolUniqueId &id);
TFunctionSymbolInfo() : mId(nullptr) {}
TFunctionSymbolInfo(const TFunctionSymbolInfo &) = default; TFunctionSymbolInfo(const TFunctionSymbolInfo &info);
TFunctionSymbolInfo &operator=(const TFunctionSymbolInfo &) = default; TFunctionSymbolInfo &operator=(const TFunctionSymbolInfo &info);
void setFromFunction(const TFunction &function); void setFromFunction(const TFunction &function);
...@@ -550,11 +552,12 @@ class TFunctionSymbolInfo ...@@ -550,11 +552,12 @@ class TFunctionSymbolInfo
void setName(const TString &name) { mName.setString(name); } void setName(const TString &name) { mName.setString(name); }
bool isMain() const { return mName.getString() == "main("; } bool isMain() const { return mName.getString() == "main("; }
void setId(int functionId) { mId = functionId; } void setId(const TSymbolUniqueId &functionId);
int getId() const { return mId; } const TSymbolUniqueId &getId() const;
private: private:
TName mName; TName mName;
int mId; TSymbolUniqueId *mId;
}; };
typedef TVector<TIntermNode *> TIntermSequence; typedef TVector<TIntermNode *> TIntermSequence;
...@@ -593,12 +596,28 @@ class TIntermAggregateBase ...@@ -593,12 +596,28 @@ class TIntermAggregateBase
class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase
{ {
public: public:
TIntermAggregate(const TType &type, TOperator op, TIntermSequence *arguments); static TIntermAggregate *CreateFunctionCall(const TFunction &func, TIntermSequence *arguments);
// If using this, ensure that there's a consistent function definition with the same symbol id
// added to the AST.
static TIntermAggregate *CreateFunctionCall(const TType &type,
const TSymbolUniqueId &id,
const TName &name,
TIntermSequence *arguments);
static TIntermAggregate *CreateBuiltInFunctionCall(const TFunction &func,
TIntermSequence *arguments);
static TIntermAggregate *CreateConstructor(const TType &type,
TOperator op,
TIntermSequence *arguments);
static TIntermAggregate *Create(const TType &type, TOperator op, TIntermSequence *arguments);
~TIntermAggregate() {} ~TIntermAggregate() {}
// Note: only supported for nodes that can be a part of an expression. // Note: only supported for nodes that can be a part of an expression.
TIntermTyped *deepCopy() const override { return new TIntermAggregate(*this); } TIntermTyped *deepCopy() const override { return new TIntermAggregate(*this); }
TIntermAggregate *shallowCopy() const;
TIntermAggregate *getAsAggregate() override { return this; } TIntermAggregate *getAsAggregate() override { return this; }
void traverse(TIntermTraverser *it) override; void traverse(TIntermTraverser *it) override;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override; bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
...@@ -619,10 +638,6 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase ...@@ -619,10 +638,6 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase
TFunctionSymbolInfo *getFunctionSymbolInfo() { return &mFunctionInfo; } TFunctionSymbolInfo *getFunctionSymbolInfo() { return &mFunctionInfo; }
const TFunctionSymbolInfo *getFunctionSymbolInfo() const { return &mFunctionInfo; } const TFunctionSymbolInfo *getFunctionSymbolInfo() const { return &mFunctionInfo; }
// Used for built-in functions under EOpCallBuiltInFunction. The function name in the symbol
// info needs to be set before calling this.
void setBuiltInFunctionPrecision();
protected: protected:
TIntermSequence mArguments; TIntermSequence mArguments;
...@@ -635,6 +650,8 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase ...@@ -635,6 +650,8 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase
TFunctionSymbolInfo mFunctionInfo; TFunctionSymbolInfo mFunctionInfo;
private: private:
TIntermAggregate(const TType &type, TOperator op, TIntermSequence *arguments);
TIntermAggregate(const TIntermAggregate &node); // note: not deleted, just private! TIntermAggregate(const TIntermAggregate &node); // note: not deleted, just private!
void setTypePrecisionAndQualifier(const TType &type); void setTypePrecisionAndQualifier(const TType &type);
...@@ -647,6 +664,10 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase ...@@ -647,6 +664,10 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase
// Returns true if precision was set according to special rules for this built-in. // Returns true if precision was set according to special rules for this built-in.
bool setPrecisionForSpecialBuiltInOp(); bool setPrecisionForSpecialBuiltInOp();
// Used for built-in functions under EOpCallBuiltInFunction. The function name in the symbol
// info needs to be set before calling this.
void setBuiltInFunctionPrecision();
}; };
// A list of statements. Either the root node which contains declarations and function definitions, // A list of statements. Either the root node which contains declarations and function definitions,
...@@ -678,7 +699,10 @@ class TIntermFunctionPrototype : public TIntermTyped, public TIntermAggregateBas ...@@ -678,7 +699,10 @@ class TIntermFunctionPrototype : public TIntermTyped, public TIntermAggregateBas
public: public:
// TODO(oetuaho@nvidia.com): See if TFunctionSymbolInfo could be added to constructor // TODO(oetuaho@nvidia.com): See if TFunctionSymbolInfo could be added to constructor
// parameters. // parameters.
TIntermFunctionPrototype(const TType &type) : TIntermTyped(type) {} TIntermFunctionPrototype(const TType &type, const TSymbolUniqueId &id)
: TIntermTyped(type), mFunctionInfo(id)
{
}
~TIntermFunctionPrototype() {} ~TIntermFunctionPrototype() {}
TIntermFunctionPrototype *getAsFunctionPrototypeNode() override { return this; } TIntermFunctionPrototype *getAsFunctionPrototypeNode() override { return this; }
...@@ -967,6 +991,20 @@ class TIntermTraverser : angle::NonCopyable ...@@ -967,6 +991,20 @@ class TIntermTraverser : angle::NonCopyable
// Start creating temporary symbols from the given temporary symbol index + 1. // Start creating temporary symbols from the given temporary symbol index + 1.
void useTemporaryIndex(unsigned int *temporaryIndex); void useTemporaryIndex(unsigned int *temporaryIndex);
static TIntermFunctionPrototype *CreateInternalFunctionPrototypeNode(
const TType &returnType,
const char *name,
const TSymbolUniqueId &functionId);
static TIntermFunctionDefinition *CreateInternalFunctionDefinitionNode(
const TType &returnType,
const char *name,
TIntermBlock *functionBody,
const TSymbolUniqueId &functionId);
static TIntermAggregate *CreateInternalFunctionCallNode(const TType &returnType,
const char *name,
const TSymbolUniqueId &functionId,
TIntermSequence *arguments);
protected: protected:
// Should only be called from traverse*() functions // Should only be called from traverse*() functions
void incrementDepth(TIntermNode *current) void incrementDepth(TIntermNode *current)
...@@ -1112,6 +1150,8 @@ class TIntermTraverser : angle::NonCopyable ...@@ -1112,6 +1150,8 @@ class TIntermTraverser : angle::NonCopyable
std::vector<NodeInsertMultipleEntry> mInsertions; std::vector<NodeInsertMultipleEntry> mInsertions;
private: private:
static TName GetInternalFunctionName(const char *name);
// To replace a single node with another on the parent node // To replace a single node with another on the parent node
struct NodeUpdateEntry struct NodeUpdateEntry
{ {
...@@ -1195,7 +1235,7 @@ class TLValueTrackingTraverser : public TIntermTraverser ...@@ -1195,7 +1235,7 @@ class TLValueTrackingTraverser : public TIntermTraverser
bool operatorRequiresLValue() const { return mOperatorRequiresLValue; } bool operatorRequiresLValue() const { return mOperatorRequiresLValue; }
// Add a function encountered during traversal to the function map. // Add a function encountered during traversal to the function map.
void addToFunctionMap(const TName &name, TIntermSequence *paramSequence); void addToFunctionMap(const TSymbolUniqueId &id, TIntermSequence *paramSequence);
// Return true if the prototype or definition of the function being called has been encountered // Return true if the prototype or definition of the function being called has been encountered
// during traversal. // during traversal.
...@@ -1211,20 +1251,8 @@ class TLValueTrackingTraverser : public TIntermTraverser ...@@ -1211,20 +1251,8 @@ class TLValueTrackingTraverser : public TIntermTraverser
bool mOperatorRequiresLValue; bool mOperatorRequiresLValue;
bool mInFunctionCallOutParameter; bool mInFunctionCallOutParameter;
struct TNameComparator // Map from function symbol id values to their parameter sequences
{ TMap<int, TIntermSequence *> mFunctionMap;
bool operator()(const TName &a, const TName &b) const
{
int compareResult = a.getString().compare(b.getString());
if (compareResult != 0)
return compareResult < 0;
// Internal functions may have same names as non-internal functions.
return !a.isInternal() && b.isInternal();
}
};
// Map from mangled function names to their parameter sequences
TMap<TName, TIntermSequence *, TNameComparator> mFunctionMap;
const TSymbolTable &mSymbolTable; const TSymbolTable &mSymbolTable;
const int mShaderVersion; const int mShaderVersion;
......
...@@ -228,22 +228,23 @@ void TIntermTraverser::nextTemporaryIndex() ...@@ -228,22 +228,23 @@ void TIntermTraverser::nextTemporaryIndex()
++(*mTemporaryIndex); ++(*mTemporaryIndex);
} }
void TLValueTrackingTraverser::addToFunctionMap(const TName &name, TIntermSequence *paramSequence) void TLValueTrackingTraverser::addToFunctionMap(const TSymbolUniqueId &id,
TIntermSequence *paramSequence)
{ {
mFunctionMap[name] = paramSequence; mFunctionMap[id.get()] = paramSequence;
} }
bool TLValueTrackingTraverser::isInFunctionMap(const TIntermAggregate *callNode) const bool TLValueTrackingTraverser::isInFunctionMap(const TIntermAggregate *callNode) const
{ {
ASSERT(callNode->getOp() == EOpCallFunctionInAST); ASSERT(callNode->getOp() == EOpCallFunctionInAST);
return (mFunctionMap.find(callNode->getFunctionSymbolInfo()->getNameObj()) != return (mFunctionMap.find(callNode->getFunctionSymbolInfo()->getId().get()) !=
mFunctionMap.end()); mFunctionMap.end());
} }
TIntermSequence *TLValueTrackingTraverser::getFunctionParameters(const TIntermAggregate *callNode) TIntermSequence *TLValueTrackingTraverser::getFunctionParameters(const TIntermAggregate *callNode)
{ {
ASSERT(isInFunctionMap(callNode)); ASSERT(isInFunctionMap(callNode));
return mFunctionMap[callNode->getFunctionSymbolInfo()->getNameObj()]; return mFunctionMap[callNode->getFunctionSymbolInfo()->getId().get()];
} }
void TLValueTrackingTraverser::setInFunctionCallOutParameter(bool inOutParameter) void TLValueTrackingTraverser::setInFunctionCallOutParameter(bool inOutParameter)
...@@ -623,7 +624,7 @@ void TIntermTraverser::traverseAggregate(TIntermAggregate *node) ...@@ -623,7 +624,7 @@ void TIntermTraverser::traverseAggregate(TIntermAggregate *node)
void TLValueTrackingTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node) void TLValueTrackingTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node)
{ {
TIntermSequence *sequence = node->getSequence(); TIntermSequence *sequence = node->getSequence();
addToFunctionMap(node->getFunctionSymbolInfo()->getNameObj(), sequence); addToFunctionMap(node->getFunctionSymbolInfo()->getId(), sequence);
TIntermTraverser::traverseFunctionPrototype(node); TIntermTraverser::traverseFunctionPrototype(node);
} }
......
...@@ -2490,7 +2490,8 @@ TIntermFunctionPrototype *TParseContext::createPrototypeNodeFromFunction( ...@@ -2490,7 +2490,8 @@ TIntermFunctionPrototype *TParseContext::createPrototypeNodeFromFunction(
const TSourceLoc &location, const TSourceLoc &location,
bool insertParametersToSymbolTable) bool insertParametersToSymbolTable)
{ {
TIntermFunctionPrototype *prototype = new TIntermFunctionPrototype(function.getReturnType()); TIntermFunctionPrototype *prototype =
new TIntermFunctionPrototype(function.getReturnType(), TSymbolUniqueId(function));
// TODO(oetuaho@nvidia.com): Instead of converting the function information here, the node could // TODO(oetuaho@nvidia.com): Instead of converting the function information here, the node could
// point to the data that already exists in the symbol table. // point to the data that already exists in the symbol table.
prototype->getFunctionSymbolInfo()->setFromFunction(function); prototype->getFunctionSymbolInfo()->setFromFunction(function);
...@@ -2803,9 +2804,8 @@ TIntermTyped *TParseContext::addConstructor(TIntermSequence *arguments, ...@@ -2803,9 +2804,8 @@ TIntermTyped *TParseContext::addConstructor(TIntermSequence *arguments,
return TIntermTyped::CreateZero(type); return TIntermTyped::CreateZero(type);
} }
TIntermAggregate *constructorNode = new TIntermAggregate(type, op, arguments); TIntermAggregate *constructorNode = TIntermAggregate::CreateConstructor(type, op, arguments);
constructorNode->setLine(line); constructorNode->setLine(line);
ASSERT(constructorNode->isConstructor());
TIntermTyped *constConstructor = TIntermTyped *constConstructor =
intermediate.foldAggregateBuiltIn(constructorNode, mDiagnostics); intermediate.foldAggregateBuiltIn(constructorNode, mDiagnostics);
...@@ -4533,7 +4533,7 @@ TIntermTyped *TParseContext::addNonConstructorFunctionCall(TFunction *fnCall, ...@@ -4533,7 +4533,7 @@ TIntermTyped *TParseContext::addNonConstructorFunctionCall(TFunction *fnCall,
else else
{ {
TIntermAggregate *callNode = TIntermAggregate *callNode =
new TIntermAggregate(fnCandidate->getReturnType(), op, arguments); TIntermAggregate::Create(fnCandidate->getReturnType(), op, arguments);
callNode->setLine(loc); callNode->setLine(loc);
// Some built-in functions have out parameters too. // Some built-in functions have out parameters too.
...@@ -4561,19 +4561,13 @@ TIntermTyped *TParseContext::addNonConstructorFunctionCall(TFunction *fnCall, ...@@ -4561,19 +4561,13 @@ TIntermTyped *TParseContext::addNonConstructorFunctionCall(TFunction *fnCall,
// This needs to happen after the function info including name is set. // This needs to happen after the function info including name is set.
if (builtIn) if (builtIn)
{ {
callNode = new TIntermAggregate(fnCandidate->getReturnType(), callNode = TIntermAggregate::CreateBuiltInFunctionCall(*fnCandidate, arguments);
EOpCallBuiltInFunction, arguments);
// Note that name needs to be set before texture function type is determined.
callNode->getFunctionSymbolInfo()->setFromFunction(*fnCandidate);
callNode->setBuiltInFunctionPrecision();
checkTextureOffsetConst(callNode); checkTextureOffsetConst(callNode);
checkImageMemoryAccessForBuiltinFunctions(callNode); checkImageMemoryAccessForBuiltinFunctions(callNode);
} }
else else
{ {
callNode = new TIntermAggregate(fnCandidate->getReturnType(), callNode = TIntermAggregate::CreateFunctionCall(*fnCandidate, arguments);
EOpCallFunctionInAST, arguments);
callNode->getFunctionSymbolInfo()->setFromFunction(*fnCandidate);
checkImageMemoryAccessForUserDefinedFunctions(fnCandidate, callNode); checkImageMemoryAccessForUserDefinedFunctions(fnCandidate, callNode);
} }
......
...@@ -20,7 +20,7 @@ namespace sh ...@@ -20,7 +20,7 @@ namespace sh
namespace namespace
{ {
TName GetIndexFunctionName(const TType &type, bool write) std::string GetIndexFunctionName(const TType &type, bool write)
{ {
TInfoSinkBase nameSink; TInfoSinkBase nameSink;
nameSink << "dyn_index_"; nameSink << "dyn_index_";
...@@ -53,12 +53,7 @@ TName GetIndexFunctionName(const TType &type, bool write) ...@@ -53,12 +53,7 @@ TName GetIndexFunctionName(const TType &type, bool write)
} }
nameSink << type.getNominalSize(); nameSink << type.getNominalSize();
} }
TString nameString = TFunction::mangleName(nameSink.c_str()); return nameSink.str();
TName name(nameString);
// TODO(oetuaho@nvidia.com): would be better to have the parameter types in the mangled name as
// well.
name.setInternal(true);
return name;
} }
TIntermSymbol *CreateBaseSymbol(const TType &type, TQualifier qualifier) TIntermSymbol *CreateBaseSymbol(const TType &type, TQualifier qualifier)
...@@ -115,7 +110,7 @@ TIntermTyped *EnsureSignedInt(TIntermTyped *node) ...@@ -115,7 +110,7 @@ TIntermTyped *EnsureSignedInt(TIntermTyped *node)
TIntermSequence *arguments = new TIntermSequence(); TIntermSequence *arguments = new TIntermSequence();
arguments->push_back(node); arguments->push_back(node);
return new TIntermAggregate(TType(EbtInt), EOpConstructInt, arguments); return TIntermAggregate::CreateConstructor(TType(EbtInt), EOpConstructInt, arguments);
} }
TType GetFieldType(const TType &indexedType) TType GetFieldType(const TType &indexedType)
...@@ -175,7 +170,9 @@ TType GetFieldType(const TType &indexedType) ...@@ -175,7 +170,9 @@ TType GetFieldType(const TType &indexedType)
// base[1] = value; // base[1] = value;
// } // }
// Note that else is not used in above functions to avoid the RewriteElseBlocks transformation. // Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
TIntermFunctionDefinition *GetIndexFunctionDefinition(TType type, bool write) TIntermFunctionDefinition *GetIndexFunctionDefinition(TType type,
bool write,
const TSymbolUniqueId &functionId)
{ {
ASSERT(!type.isArray()); ASSERT(!type.isArray());
// Conservatively use highp here, even if the indexed type is not highp. That way the code can't // Conservatively use highp here, even if the indexed type is not highp. That way the code can't
...@@ -195,16 +192,15 @@ TIntermFunctionDefinition *GetIndexFunctionDefinition(TType type, bool write) ...@@ -195,16 +192,15 @@ TIntermFunctionDefinition *GetIndexFunctionDefinition(TType type, bool write)
numCases = type.getNominalSize(); numCases = type.getNominalSize();
} }
TIntermFunctionPrototype *prototypeNode = nullptr; TType returnType(EbtVoid);
if (write) if (!write)
{
prototypeNode = new TIntermFunctionPrototype(TType(EbtVoid));
}
else
{ {
prototypeNode = new TIntermFunctionPrototype(fieldType); returnType = fieldType;
} }
prototypeNode->getFunctionSymbolInfo()->setNameObj(GetIndexFunctionName(type, write));
std::string functionName = GetIndexFunctionName(type, write);
TIntermFunctionPrototype *prototypeNode = TIntermTraverser::CreateInternalFunctionPrototypeNode(
returnType, functionName.c_str(), functionId);
TQualifier baseQualifier = EvqInOut; TQualifier baseQualifier = EvqInOut;
if (!write) if (!write)
...@@ -305,10 +301,11 @@ class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser ...@@ -305,10 +301,11 @@ class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
bool usedTreeInsertion() const { return mUsedTreeInsertion; } bool usedTreeInsertion() const { return mUsedTreeInsertion; }
protected: protected:
// Sets of types that are indexed. Note that these can not store multiple variants // Maps of types that are indexed to the indexing function ids used for them. Note that these
// of the same type with different precisions - only one precision gets stored. // can not store multiple variants of the same type with different precisions - only one
std::set<TType> mIndexedVecAndMatrixTypes; // precision gets stored.
std::set<TType> mWrittenVecAndMatrixTypes; std::map<TType, TSymbolUniqueId> mIndexedVecAndMatrixTypes;
std::map<TType, TSymbolUniqueId> mWrittenVecAndMatrixTypes;
bool mUsedTreeInsertion; bool mUsedTreeInsertion;
...@@ -332,49 +329,51 @@ void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root) ...@@ -332,49 +329,51 @@ void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root)
TIntermBlock *rootBlock = root->getAsBlock(); TIntermBlock *rootBlock = root->getAsBlock();
ASSERT(rootBlock != nullptr); ASSERT(rootBlock != nullptr);
TIntermSequence insertions; TIntermSequence insertions;
for (TType type : mIndexedVecAndMatrixTypes) for (auto &type : mIndexedVecAndMatrixTypes)
{ {
insertions.push_back(GetIndexFunctionDefinition(type, false)); insertions.push_back(GetIndexFunctionDefinition(type.first, false, type.second));
} }
for (TType type : mWrittenVecAndMatrixTypes) for (auto &type : mWrittenVecAndMatrixTypes)
{ {
insertions.push_back(GetIndexFunctionDefinition(type, true)); insertions.push_back(GetIndexFunctionDefinition(type.first, true, type.second));
} }
mInsertions.push_back(NodeInsertMultipleEntry(rootBlock, 0, insertions, TIntermSequence())); mInsertions.push_back(NodeInsertMultipleEntry(rootBlock, 0, insertions, TIntermSequence()));
} }
// Create a call to dyn_index_*() based on an indirect indexing op node // Create a call to dyn_index_*() based on an indirect indexing op node
TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node, TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
TIntermTyped *indexedNode, TIntermTyped *index,
TIntermTyped *index) const TSymbolUniqueId &functionId)
{ {
ASSERT(node->getOp() == EOpIndexIndirect); ASSERT(node->getOp() == EOpIndexIndirect);
TIntermSequence *arguments = new TIntermSequence(); TIntermSequence *arguments = new TIntermSequence();
arguments->push_back(indexedNode); arguments->push_back(node->getLeft());
arguments->push_back(index); arguments->push_back(index);
TType fieldType = GetFieldType(indexedNode->getType()); TType fieldType = GetFieldType(node->getLeft()->getType());
TIntermAggregate *indexingCall = std::string functionName = GetIndexFunctionName(node->getLeft()->getType(), false);
new TIntermAggregate(fieldType, EOpCallFunctionInAST, arguments); TIntermAggregate *indexingCall = TIntermTraverser::CreateInternalFunctionCallNode(
fieldType, functionName.c_str(), functionId, arguments);
indexingCall->setLine(node->getLine()); indexingCall->setLine(node->getLine());
indexingCall->getFunctionSymbolInfo()->setNameObj(
GetIndexFunctionName(indexedNode->getType(), false));
return indexingCall; return indexingCall;
} }
TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node, TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
TIntermTyped *index, TIntermTyped *index,
TIntermTyped *writtenValue) TIntermTyped *writtenValue,
const TSymbolUniqueId &functionId)
{ {
// Deep copy the left node so that two pointers to the same node don't end up in the tree. ASSERT(node->getOp() == EOpIndexIndirect);
TIntermNode *leftCopy = node->getLeft()->deepCopy(); TIntermSequence *arguments = new TIntermSequence();
ASSERT(leftCopy != nullptr && leftCopy->getAsTyped() != nullptr); // Deep copy the child nodes so that two pointers to the same node don't end up in the tree.
TIntermAggregate *indexedWriteCall = arguments->push_back(node->getLeft()->deepCopy());
CreateIndexFunctionCall(node, leftCopy->getAsTyped(), index); arguments->push_back(index->deepCopy());
indexedWriteCall->getFunctionSymbolInfo()->setNameObj( arguments->push_back(writtenValue);
GetIndexFunctionName(node->getLeft()->getType(), true));
indexedWriteCall->setType(TType(EbtVoid)); std::string functionName = GetIndexFunctionName(node->getLeft()->getType(), true);
indexedWriteCall->getSequence()->push_back(writtenValue); TIntermAggregate *indexedWriteCall = TIntermTraverser::CreateInternalFunctionCallNode(
TType(EbtVoid), functionName.c_str(), functionId, arguments);
indexedWriteCall->setLine(node->getLine());
return indexedWriteCall; return indexedWriteCall;
} }
...@@ -415,8 +414,16 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod ...@@ -415,8 +414,16 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod
ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write); ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write);
#endif #endif
TType type = node->getLeft()->getType(); const TType &type = node->getLeft()->getType();
mIndexedVecAndMatrixTypes.insert(type); TSymbolUniqueId indexingFunctionId;
if (mIndexedVecAndMatrixTypes.find(type) == mIndexedVecAndMatrixTypes.end())
{
mIndexedVecAndMatrixTypes[type] = indexingFunctionId;
}
else
{
indexingFunctionId = mIndexedVecAndMatrixTypes[type];
}
if (write) if (write)
{ {
...@@ -450,7 +457,15 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod ...@@ -450,7 +457,15 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod
// TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value
// only writes it and doesn't need the previous value. http://anglebug.com/1116 // only writes it and doesn't need the previous value. http://anglebug.com/1116
mWrittenVecAndMatrixTypes.insert(type); TSymbolUniqueId indexedWriteFunctionId;
if (mWrittenVecAndMatrixTypes.find(type) == mWrittenVecAndMatrixTypes.end())
{
mWrittenVecAndMatrixTypes[type] = indexedWriteFunctionId;
}
else
{
indexedWriteFunctionId = mWrittenVecAndMatrixTypes[type];
}
TType fieldType = GetFieldType(type); TType fieldType = GetFieldType(type);
TIntermSequence insertionsBefore; TIntermSequence insertionsBefore;
...@@ -462,19 +477,19 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod ...@@ -462,19 +477,19 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod
initIndex->setLine(node->getLine()); initIndex->setLine(node->getLine());
insertionsBefore.push_back(initIndex); insertionsBefore.push_back(initIndex);
TIntermAggregate *indexingCall = CreateIndexFunctionCall(
node, node->getLeft(), createTempSymbol(indexInitializer->getType()));
// Create a node for referring to the index after the nextTemporaryIndex() call // Create a node for referring to the index after the nextTemporaryIndex() call
// below. // below.
TIntermSymbol *tempIndex = createTempSymbol(indexInitializer->getType()); TIntermSymbol *tempIndex = createTempSymbol(indexInitializer->getType());
TIntermAggregate *indexingCall =
CreateIndexFunctionCall(node, tempIndex, indexingFunctionId);
nextTemporaryIndex(); // From now on, creating temporary symbols that refer to the nextTemporaryIndex(); // From now on, creating temporary symbols that refer to the
// field value. // field value.
insertionsBefore.push_back(createTempInitDeclaration(indexingCall)); insertionsBefore.push_back(createTempInitDeclaration(indexingCall));
TIntermAggregate *indexedWriteCall = TIntermAggregate *indexedWriteCall = CreateIndexedWriteFunctionCall(
CreateIndexedWriteFunctionCall(node, tempIndex, createTempSymbol(fieldType)); node, tempIndex, createTempSymbol(fieldType), indexedWriteFunctionId);
insertionsAfter.push_back(indexedWriteCall); insertionsAfter.push_back(indexedWriteCall);
insertStatementsInParentBlock(insertionsBefore, insertionsAfter); insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
queueReplacement(node, createTempSymbol(fieldType), OriginalNode::IS_DROPPED); queueReplacement(node, createTempSymbol(fieldType), OriginalNode::IS_DROPPED);
...@@ -489,7 +504,7 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod ...@@ -489,7 +504,7 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod
// If the index_expr is unsigned, we'll convert it to signed. // If the index_expr is unsigned, we'll convert it to signed.
ASSERT(!mRemoveIndexSideEffectsInSubtree); ASSERT(!mRemoveIndexSideEffectsInSubtree);
TIntermAggregate *indexingCall = CreateIndexFunctionCall( TIntermAggregate *indexingCall = CreateIndexFunctionCall(
node, node->getLeft(), EnsureSignedInt(node->getRight())); node, EnsureSignedInt(node->getRight()), indexingFunctionId);
queueReplacement(node, indexingCall, OriginalNode::IS_DROPPED); queueReplacement(node, indexingCall, OriginalNode::IS_DROPPED);
} }
} }
......
...@@ -89,8 +89,7 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -89,8 +89,7 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
16, node->getFunctionSymbolInfo()->getName().length() - 20); 16, node->getFunctionSymbolInfo()->getName().length() - 20);
TString newName = "texelFetch" + newArgs; TString newName = "texelFetch" + newArgs;
TSymbol *texelFetchSymbol = symbolTable->findBuiltIn(newName, shaderVersion); TSymbol *texelFetchSymbol = symbolTable->findBuiltIn(newName, shaderVersion);
ASSERT(texelFetchSymbol); ASSERT(texelFetchSymbol && texelFetchSymbol->isFunction());
int uniqueId = texelFetchSymbol->getUniqueId();
// Create new node that represents the call of function texelFetch. // Create new node that represents the call of function texelFetch.
// Its argument list will be: texelFetch(sampler, Position+offset, lod). // Its argument list will be: texelFetch(sampler, Position+offset, lod).
...@@ -117,8 +116,8 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -117,8 +116,8 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
TIntermTyped *zeroNode = TIntermTyped::CreateZero(TType(EbtInt)); TIntermTyped *zeroNode = TIntermTyped::CreateZero(TType(EbtInt));
constructOffsetIvecArguments->push_back(zeroNode); constructOffsetIvecArguments->push_back(zeroNode);
offsetNode = new TIntermAggregate(texCoordNode->getType(), EOpConstructIVec3, offsetNode = TIntermAggregate::CreateConstructor(texCoordNode->getType(), EOpConstructIVec3,
constructOffsetIvecArguments); constructOffsetIvecArguments);
offsetNode->setLine(texCoordNode->getLine()); offsetNode->setLine(texCoordNode->getLine());
} }
else else
...@@ -136,10 +135,8 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -136,10 +135,8 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
ASSERT(texelFetchArguments->size() == 3u); ASSERT(texelFetchArguments->size() == 3u);
TIntermAggregate *texelFetchNode = TIntermAggregate *texelFetchNode = TIntermAggregate::CreateBuiltInFunctionCall(
new TIntermAggregate(node->getType(), EOpCallBuiltInFunction, texelFetchArguments); *static_cast<const TFunction *>(texelFetchSymbol), texelFetchArguments);
texelFetchNode->getFunctionSymbolInfo()->setName(newName);
texelFetchNode->getFunctionSymbolInfo()->setId(uniqueId);
texelFetchNode->setLine(node->getLine()); texelFetchNode->setLine(node->getLine());
// Replace the old node by this new node. // Replace the old node by this new node.
......
...@@ -55,16 +55,6 @@ TIntermBinary *CopyAssignmentNode(TIntermBinary *node) ...@@ -55,16 +55,6 @@ TIntermBinary *CopyAssignmentNode(TIntermBinary *node)
return new TIntermBinary(node->getOp(), node->getLeft(), node->getRight()); return new TIntermBinary(node->getOp(), node->getLeft(), node->getRight());
} }
// Performs a shallow copy of a constructor/function call node.
TIntermAggregate *CopyAggregateNode(TIntermAggregate *node)
{
TIntermSequence *copySeq = new TIntermSequence();
copySeq->insert(copySeq->begin(), node->getSequence()->begin(), node->getSequence()->end());
TIntermAggregate *copyNode = new TIntermAggregate(node->getType(), node->getOp(), copySeq);
*copyNode->getFunctionSymbolInfo() = *node->getFunctionSymbolInfo();
return copyNode;
}
bool SeparateExpressionsTraverser::visitBinary(Visit visit, TIntermBinary *node) bool SeparateExpressionsTraverser::visitBinary(Visit visit, TIntermBinary *node)
{ {
if (mFoundArrayExpression) if (mFoundArrayExpression)
...@@ -104,7 +94,7 @@ bool SeparateExpressionsTraverser::visitAggregate(Visit visit, TIntermAggregate ...@@ -104,7 +94,7 @@ bool SeparateExpressionsTraverser::visitAggregate(Visit visit, TIntermAggregate
mFoundArrayExpression = true; mFoundArrayExpression = true;
TIntermSequence insertions; TIntermSequence insertions;
insertions.push_back(createTempInitDeclaration(CopyAggregateNode(node))); insertions.push_back(createTempInitDeclaration(node->shallowCopy()));
insertStatementsInParentBlock(insertions); insertStatementsInParentBlock(insertions);
queueReplacement(node, createTempSymbol(node->getType()), OriginalNode::IS_DROPPED); queueReplacement(node, createTempSymbol(node->getType()), OriginalNode::IS_DROPPED);
......
...@@ -26,6 +26,19 @@ namespace sh ...@@ -26,6 +26,19 @@ namespace sh
int TSymbolTable::uniqueIdCounter = 0; int TSymbolTable::uniqueIdCounter = 0;
TSymbolUniqueId::TSymbolUniqueId() : mId(TSymbolTable::nextUniqueId())
{
}
TSymbolUniqueId::TSymbolUniqueId(const TSymbol &symbol) : mId(symbol.getUniqueId())
{
}
int TSymbolUniqueId::get() const
{
return mId;
}
TSymbol::TSymbol(const TString *n) : uniqueId(TSymbolTable::nextUniqueId()), name(n) TSymbol::TSymbol(const TString *n) : uniqueId(TSymbolTable::nextUniqueId()), name(n)
{ {
} }
......
...@@ -41,6 +41,22 @@ ...@@ -41,6 +41,22 @@
namespace sh namespace sh
{ {
// Encapsulates a unique id for a symbol.
class TSymbolUniqueId
{
public:
POOL_ALLOCATOR_NEW_DELETE();
TSymbolUniqueId();
TSymbolUniqueId(const TSymbol &symbol);
TSymbolUniqueId(const TSymbolUniqueId &) = default;
TSymbolUniqueId &operator=(const TSymbolUniqueId &) = default;
int get() const;
private:
int mId;
};
// Symbol base class. (Can build functions or variables out of these...) // Symbol base class. (Can build functions or variables out of these...)
class TSymbol : angle::NonCopyable class TSymbol : angle::NonCopyable
{ {
......
...@@ -17,7 +17,7 @@ void OutputFunction(TInfoSinkBase &out, const char *str, TFunctionSymbolInfo *in ...@@ -17,7 +17,7 @@ void OutputFunction(TInfoSinkBase &out, const char *str, TFunctionSymbolInfo *in
{ {
const char *internal = info->getNameObj().isInternal() ? " (internal function)" : ""; const char *internal = info->getNameObj().isInternal() ? " (internal function)" : "";
out << str << internal << ": " << info->getNameObj().getString() << " (symbol id " out << str << internal << ": " << info->getNameObj().getString() << " (symbol id "
<< info->getId() << ")"; << info->getId().get() << ")";
} }
// //
......
...@@ -189,7 +189,7 @@ TEST_F(IntermNodeTest, DeepCopyAggregateNode) ...@@ -189,7 +189,7 @@ TEST_F(IntermNodeTest, DeepCopyAggregateNode)
originalSeq->push_back(createTestSymbol()); originalSeq->push_back(createTestSymbol());
originalSeq->push_back(createTestSymbol()); originalSeq->push_back(createTestSymbol());
TIntermAggregate *original = TIntermAggregate *original =
new TIntermAggregate(originalSeq->at(0)->getAsTyped()->getType(), EOpMix, originalSeq); TIntermAggregate::Create(originalSeq->at(0)->getAsTyped()->getType(), EOpMix, originalSeq);
original->setLine(getTestSourceLoc()); original->setLine(getTestSourceLoc());
TIntermTyped *copyTyped = original->deepCopy(); TIntermTyped *copyTyped = original->deepCopy();
......
...@@ -69,7 +69,7 @@ TEST_F(PruneUnusedFunctionsTest, UnimplementedPrototype) ...@@ -69,7 +69,7 @@ TEST_F(PruneUnusedFunctionsTest, UnimplementedPrototype)
EXPECT_TRUE(foundInCode("main(", 1)); EXPECT_TRUE(foundInCode("main(", 1));
} }
// Check that used functions are not prunued (duh) // Check that used functions are not pruned (duh)
TEST_F(PruneUnusedFunctionsTest, UsedFunction) TEST_F(PruneUnusedFunctionsTest, UsedFunction)
{ {
const std::string &shaderString = const std::string &shaderString =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment