Commit b6af22b5 by Olli Etuaho Committed by Commit Bot

Store TVariable* in TIntermSymbol instead of storing id

This is an intermediate step to only storing a TVariable * in TIntermSymbol instead of copying the name. This makes it possible to get a constant value out of a TIntermSymbol without doing a symbol table lookup. BUG=angleproject:2267 TEST=angle_unittests Change-Id: Ibff588241a4ad4ac330063296273288b20a072c9 Reviewed-on: https://chromium-review.googlesource.com/829142 Commit-Queue: Olli Etuaho <oetuaho@nvidia.com> Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org>
parent 755a9317
...@@ -281,7 +281,7 @@ bool TIntermAggregateBase::insertChildNodes(TIntermSequence::size_type position, ...@@ -281,7 +281,7 @@ bool TIntermAggregateBase::insertChildNodes(TIntermSequence::size_type position,
} }
TIntermSymbol::TIntermSymbol(const TVariable *variable) TIntermSymbol::TIntermSymbol(const TVariable *variable)
: TIntermTyped(variable->getType()), mId(variable->uniqueId()), mSymbol(variable->name()) : TIntermTyped(variable->getType()), mVariable(variable), mSymbol(variable->name())
{ {
if (variable->symbolType() == SymbolType::AngleInternal) if (variable->symbolType() == SymbolType::AngleInternal)
{ {
...@@ -289,6 +289,11 @@ TIntermSymbol::TIntermSymbol(const TVariable *variable) ...@@ -289,6 +289,11 @@ TIntermSymbol::TIntermSymbol(const TVariable *variable)
} }
} }
const TSymbolUniqueId &TIntermSymbol::uniqueId() const
{
return mVariable->uniqueId();
}
TIntermAggregate *TIntermAggregate::CreateFunctionCall(const TFunction &func, TIntermAggregate *TIntermAggregate::CreateFunctionCall(const TFunction &func,
TIntermSequence *arguments) TIntermSequence *arguments)
{ {
......
...@@ -263,20 +263,20 @@ class TIntermSymbol : public TIntermTyped ...@@ -263,20 +263,20 @@ class TIntermSymbol : public TIntermTyped
bool hasSideEffects() const override { return false; } bool hasSideEffects() const override { return false; }
int getId() const { return mId.get(); } const TSymbolUniqueId &uniqueId() const;
const TString &getSymbol() const { return mSymbol.getString(); } const TString &getSymbol() const { return mSymbol.getString(); }
const TName &getName() const { return mSymbol; } const TName &getName() const { return mSymbol; }
const TVariable &variable() const { return *mVariable; }
void traverse(TIntermTraverser *it) override; void traverse(TIntermTraverser *it) override;
TIntermSymbol *getAsSymbolNode() override { return this; } TIntermSymbol *getAsSymbolNode() override { return this; }
bool replaceChildNode(TIntermNode *, TIntermNode *) override { return false; } bool replaceChildNode(TIntermNode *, TIntermNode *) override { return false; }
protected:
const TSymbolUniqueId mId;
TName mSymbol;
private: private:
TIntermSymbol(const TIntermSymbol &) = default; // Note: not deleted, just private! TIntermSymbol(const TIntermSymbol &) = default; // Note: not deleted, just private!
const TVariable *const mVariable; // Guaranteed to be non-null
TName mSymbol;
}; };
// A Raw node stores raw code, that the translator will insert verbatim // A Raw node stores raw code, that the translator will insert verbatim
......
...@@ -2408,7 +2408,7 @@ bool OutputHLSL::handleExcessiveLoop(TInfoSinkBase &out, TIntermLoop *node) ...@@ -2408,7 +2408,7 @@ bool OutputHLSL::handleExcessiveLoop(TInfoSinkBase &out, TIntermLoop *node)
{ {
TIntermBinary *test = node->getCondition()->getAsBinaryNode(); TIntermBinary *test = node->getCondition()->getAsBinaryNode();
if (test && test->getLeft()->getAsSymbolNode()->getId() == index->getId()) if (test && test->getLeft()->getAsSymbolNode()->uniqueId() == index->uniqueId())
{ {
TIntermConstantUnion *constant = test->getRight()->getAsConstantUnion(); TIntermConstantUnion *constant = test->getRight()->getAsConstantUnion();
......
...@@ -82,7 +82,7 @@ void TOutputTraverser::visitSymbol(TIntermSymbol *node) ...@@ -82,7 +82,7 @@ void TOutputTraverser::visitSymbol(TIntermSymbol *node)
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, mDepth);
mOut << "'" << node->getSymbol() << "' "; mOut << "'" << node->getSymbol() << "' ";
mOut << "(symbol id " << node->getId() << ") "; mOut << "(symbol id " << node->uniqueId().get() << ") ";
mOut << "(" << node->getCompleteString() << ")"; mOut << "(" << node->getCompleteString() << ")";
mOut << "\n"; mOut << "\n";
} }
......
...@@ -1996,11 +1996,8 @@ bool TParseContext::executeInitializer(const TSourceLoc &line, ...@@ -1996,11 +1996,8 @@ bool TParseContext::executeInitializer(const TSourceLoc &line,
} }
else if (initializer->getAsSymbolNode()) else if (initializer->getAsSymbolNode())
{ {
const TSymbol *symbol = const TVariable &var = initializer->getAsSymbolNode()->variable();
symbolTable.find(initializer->getAsSymbolNode()->getSymbol(), 0); const TConstantUnion *constArray = var.getConstPointer();
const TVariable *tVar = static_cast<const TVariable *>(symbol);
const TConstantUnion *constArray = tVar->getConstPointer();
if (constArray) if (constArray)
{ {
variable->shareConstPointer(constArray); variable->shareConstPointer(constArray);
......
...@@ -92,10 +92,10 @@ void CollectVariableRefCountsTraverser::visitSymbol(TIntermSymbol *node) ...@@ -92,10 +92,10 @@ void CollectVariableRefCountsTraverser::visitSymbol(TIntermSymbol *node)
{ {
incrementStructTypeRefCount(node->getType()); incrementStructTypeRefCount(node->getType());
auto iter = mSymbolIdRefCounts.find(node->getId()); auto iter = mSymbolIdRefCounts.find(node->uniqueId().get());
if (iter == mSymbolIdRefCounts.end()) if (iter == mSymbolIdRefCounts.end())
{ {
mSymbolIdRefCounts[node->getId()] = 1u; mSymbolIdRefCounts[node->uniqueId().get()] = 1u;
return; return;
} }
++(iter->second); ++(iter->second);
...@@ -234,14 +234,14 @@ bool RemoveUnreferencedVariablesTraverser::visitDeclaration(Visit visit, TInterm ...@@ -234,14 +234,14 @@ bool RemoveUnreferencedVariablesTraverser::visitDeclaration(Visit visit, TInterm
TIntermSymbol *symbolNode = declarator->getAsSymbolNode(); TIntermSymbol *symbolNode = declarator->getAsSymbolNode();
if (symbolNode != nullptr) if (symbolNode != nullptr)
{ {
canRemoveVariable = canRemoveVariable = (*mSymbolIdRefCounts)[symbolNode->uniqueId().get()] == 1u ||
(*mSymbolIdRefCounts)[symbolNode->getId()] == 1u || symbolNode->getSymbol().empty(); symbolNode->getSymbol().empty();
} }
TIntermBinary *initNode = declarator->getAsBinaryNode(); TIntermBinary *initNode = declarator->getAsBinaryNode();
if (initNode != nullptr) if (initNode != nullptr)
{ {
ASSERT(initNode->getLeft()->getAsSymbolNode()); ASSERT(initNode->getLeft()->getAsSymbolNode());
int symbolId = initNode->getLeft()->getAsSymbolNode()->getId(); int symbolId = initNode->getLeft()->getAsSymbolNode()->uniqueId().get();
canRemoveVariable = canRemoveVariable =
(*mSymbolIdRefCounts)[symbolId] == 1u && !initNode->getRight()->hasSideEffects(); (*mSymbolIdRefCounts)[symbolId] == 1u && !initNode->getRight()->hasSideEffects();
} }
...@@ -262,8 +262,8 @@ void RemoveUnreferencedVariablesTraverser::visitSymbol(TIntermSymbol *node) ...@@ -262,8 +262,8 @@ void RemoveUnreferencedVariablesTraverser::visitSymbol(TIntermSymbol *node)
{ {
if (mRemoveReferences) if (mRemoveReferences)
{ {
ASSERT(mSymbolIdRefCounts->find(node->getId()) != mSymbolIdRefCounts->end()); ASSERT(mSymbolIdRefCounts->find(node->uniqueId().get()) != mSymbolIdRefCounts->end());
--(*mSymbolIdRefCounts)[node->getId()]; --(*mSymbolIdRefCounts)[node->uniqueId().get()];
decrementStructTypeRefCount(node->getType()); decrementStructTypeRefCount(node->getType());
} }
......
...@@ -20,9 +20,17 @@ TSymbolUniqueId::TSymbolUniqueId(const TSymbol &symbol) : mId(symbol.uniqueId(). ...@@ -20,9 +20,17 @@ TSymbolUniqueId::TSymbolUniqueId(const TSymbol &symbol) : mId(symbol.uniqueId().
{ {
} }
TSymbolUniqueId::TSymbolUniqueId(const TSymbolUniqueId &) = default;
TSymbolUniqueId &TSymbolUniqueId::operator=(const TSymbolUniqueId &) = default;
int TSymbolUniqueId::get() const int TSymbolUniqueId::get() const
{ {
return mId; return mId;
} }
bool TSymbolUniqueId::operator==(const TSymbolUniqueId &other) const
{
return mId == other.mId;
}
} // namespace sh } // namespace sh
...@@ -22,8 +22,9 @@ class TSymbolUniqueId ...@@ -22,8 +22,9 @@ class TSymbolUniqueId
POOL_ALLOCATOR_NEW_DELETE(); POOL_ALLOCATOR_NEW_DELETE();
explicit TSymbolUniqueId(TSymbolTable *symbolTable); explicit TSymbolUniqueId(TSymbolTable *symbolTable);
explicit TSymbolUniqueId(const TSymbol &symbol); explicit TSymbolUniqueId(const TSymbol &symbol);
TSymbolUniqueId(const TSymbolUniqueId &) = default; TSymbolUniqueId(const TSymbolUniqueId &);
TSymbolUniqueId &operator=(const TSymbolUniqueId &) = default; TSymbolUniqueId &operator=(const TSymbolUniqueId &);
bool operator==(const TSymbolUniqueId &) const;
int get() const; int get() const;
......
...@@ -25,7 +25,7 @@ int GetLoopSymbolId(TIntermLoop *loop) ...@@ -25,7 +25,7 @@ int GetLoopSymbolId(TIntermLoop *loop)
TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode(); TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode();
TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode(); TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
return symbol->getId(); return symbol->uniqueId().get();
} }
// Traverses a node to check if it represents a constant index expression. // Traverses a node to check if it represents a constant index expression.
...@@ -55,7 +55,7 @@ class ValidateConstIndexExpr : public TIntermTraverser ...@@ -55,7 +55,7 @@ class ValidateConstIndexExpr : public TIntermTraverser
if (mValid) if (mValid)
{ {
bool isLoopSymbol = std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(), bool isLoopSymbol = std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(),
symbol->getId()) != mLoopSymbolIds.end(); symbol->uniqueId().get()) != mLoopSymbolIds.end();
mValid = (symbol->getQualifier() == EvqConst) || isLoopSymbol; mValid = (symbol->getQualifier() == EvqConst) || isLoopSymbol;
} }
} }
...@@ -165,7 +165,7 @@ void ValidateLimitationsTraverser::error(TSourceLoc loc, const char *reason, con ...@@ -165,7 +165,7 @@ void ValidateLimitationsTraverser::error(TSourceLoc loc, const char *reason, con
bool ValidateLimitationsTraverser::isLoopIndex(TIntermSymbol *symbol) bool ValidateLimitationsTraverser::isLoopIndex(TIntermSymbol *symbol)
{ {
return std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(), symbol->getId()) != return std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(), symbol->uniqueId().get()) !=
mLoopSymbolIds.end(); mLoopSymbolIds.end();
} }
...@@ -252,7 +252,7 @@ int ValidateLimitationsTraverser::validateForLoopInit(TIntermLoop *node) ...@@ -252,7 +252,7 @@ int ValidateLimitationsTraverser::validateForLoopInit(TIntermLoop *node)
return -1; return -1;
} }
return symbol->getId(); return symbol->uniqueId().get();
} }
bool ValidateLimitationsTraverser::validateForLoopCond(TIntermLoop *node, int indexSymbolId) bool ValidateLimitationsTraverser::validateForLoopCond(TIntermLoop *node, int indexSymbolId)
...@@ -280,7 +280,7 @@ bool ValidateLimitationsTraverser::validateForLoopCond(TIntermLoop *node, int in ...@@ -280,7 +280,7 @@ bool ValidateLimitationsTraverser::validateForLoopCond(TIntermLoop *node, int in
error(binOp->getLine(), "Invalid condition", "for"); error(binOp->getLine(), "Invalid condition", "for");
return false; return false;
} }
if (symbol->getId() != indexSymbolId) if (symbol->uniqueId().get() != indexSymbolId)
{ {
error(symbol->getLine(), "Expected loop index", symbol->getSymbol().c_str()); error(symbol->getLine(), "Expected loop index", symbol->getSymbol().c_str());
return false; return false;
...@@ -351,7 +351,7 @@ bool ValidateLimitationsTraverser::validateForLoopExpr(TIntermLoop *node, int in ...@@ -351,7 +351,7 @@ bool ValidateLimitationsTraverser::validateForLoopExpr(TIntermLoop *node, int in
error(expr->getLine(), "Invalid expression", "for"); error(expr->getLine(), "Invalid expression", "for");
return false; return false;
} }
if (symbol->getId() != indexSymbolId) if (symbol->uniqueId().get() != indexSymbolId)
{ {
error(symbol->getLine(), "Expected loop index", symbol->getSymbol().c_str()); error(symbol->getLine(), "Expected loop index", symbol->getSymbol().c_str());
return false; return false;
......
...@@ -74,7 +74,8 @@ class IntermNodeTest : public testing::Test ...@@ -74,7 +74,8 @@ class IntermNodeTest : public testing::Test
ASSERT_NE(nullptr, copy); ASSERT_NE(nullptr, copy);
ASSERT_NE(nullptr, original); ASSERT_NE(nullptr, original);
ASSERT_NE(original, copy); ASSERT_NE(original, copy);
ASSERT_EQ(original->getId(), copy->getId()); ASSERT_EQ(&original->variable(), &copy->variable());
ASSERT_EQ(original->uniqueId(), copy->uniqueId());
ASSERT_EQ(original->getName().getString(), copy->getName().getString()); ASSERT_EQ(original->getName().getString(), copy->getName().getString());
ASSERT_EQ(original->getName().isInternal(), copy->getName().isInternal()); ASSERT_EQ(original->getName().isInternal(), copy->getName().isInternal());
checkTypeEqualWithQualifiers(original->getType(), copy->getType()); checkTypeEqualWithQualifiers(original->getType(), copy->getType());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment