Commit 195be942 by Olli Etuaho Committed by Commit Bot

Always create TVariables for TIntermSymbol nodes

TIntermSymbol nodes are now constructed based on a specific TVariable. This makes sure that all TIntermSymbol nodes that are created to refer to a specific temporary in an AST transform will have consistent data. The TVariable objects are not necessarily added to the symbol table levels - just those variables that can be referred to by their name during parsing need to be reachable through there. In the future this can be taken a step further so that TIntermSymbol nodes just to point to a TVariable instead of duplicating the information. BUG=angleproject:2267 TEST=angle_unittests Change-Id: I4e7bcdb0637cd3b588d3c202ef02f4b7bd7954a1 Reviewed-on: https://chromium-review.googlesource.com/811925 Commit-Queue: Olli Etuaho <oetuaho@nvidia.com> Reviewed-by: 's avatarCorentin Wallez <cwallez@chromium.org>
parent f414121d
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <map> #include <map>
#include "compiler/translator/IntermNode_util.h"
#include "compiler/translator/IntermTraverse.h" #include "compiler/translator/IntermTraverse.h"
#include "compiler/translator/SymbolTable.h" #include "compiler/translator/SymbolTable.h"
...@@ -28,14 +29,6 @@ void CopyAggregateChildren(TIntermAggregateBase *from, TIntermAggregateBase *to) ...@@ -28,14 +29,6 @@ void CopyAggregateChildren(TIntermAggregateBase *from, TIntermAggregateBase *to)
} }
} }
TIntermSymbol *CreateReturnValueSymbol(const TSymbolUniqueId &id, const TType &type)
{
TIntermSymbol *node = new TIntermSymbol(id, "angle_return", type);
node->setInternal(true);
node->getTypePointer()->setQualifier(EvqOut);
return node;
}
TIntermAggregate *CreateReplacementCall(TIntermAggregate *originalCall, TIntermAggregate *CreateReplacementCall(TIntermAggregate *originalCall,
TIntermTyped *returnValueTarget) TIntermTyped *returnValueTarget)
{ {
...@@ -70,8 +63,10 @@ class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser ...@@ -70,8 +63,10 @@ class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser
// Set when traversal is inside a function with array return value. // Set when traversal is inside a function with array return value.
TIntermFunctionDefinition *mFunctionWithArrayReturnValue; TIntermFunctionDefinition *mFunctionWithArrayReturnValue;
// Map from function symbol ids to array return value ids. // Map from function symbol ids to array return value variables.
std::map<int, TSymbolUniqueId *> mReturnValueIds; std::map<int, TVariable *> mReturnValueVariables;
const TString *const mReturnValueVariableName;
}; };
void ArrayReturnValueToOutParameterTraverser::apply(TIntermNode *root, TSymbolTable *symbolTable) void ArrayReturnValueToOutParameterTraverser::apply(TIntermNode *root, TSymbolTable *symbolTable)
...@@ -83,7 +78,9 @@ void ArrayReturnValueToOutParameterTraverser::apply(TIntermNode *root, TSymbolTa ...@@ -83,7 +78,9 @@ void ArrayReturnValueToOutParameterTraverser::apply(TIntermNode *root, TSymbolTa
ArrayReturnValueToOutParameterTraverser::ArrayReturnValueToOutParameterTraverser( ArrayReturnValueToOutParameterTraverser::ArrayReturnValueToOutParameterTraverser(
TSymbolTable *symbolTable) TSymbolTable *symbolTable)
: TIntermTraverser(true, false, true, symbolTable), mFunctionWithArrayReturnValue(nullptr) : TIntermTraverser(true, false, true, symbolTable),
mFunctionWithArrayReturnValue(nullptr),
mReturnValueVariableName(NewPoolTString("angle_return"))
{ {
} }
...@@ -114,12 +111,16 @@ bool ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(Visit visit ...@@ -114,12 +111,16 @@ bool ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(Visit visit
new TIntermFunctionPrototype(TType(EbtVoid), node->getFunctionSymbolInfo()->getId()); new TIntermFunctionPrototype(TType(EbtVoid), node->getFunctionSymbolInfo()->getId());
CopyAggregateChildren(node, replacement); CopyAggregateChildren(node, replacement);
const TSymbolUniqueId &functionId = node->getFunctionSymbolInfo()->getId(); const TSymbolUniqueId &functionId = node->getFunctionSymbolInfo()->getId();
if (mReturnValueIds.find(functionId.get()) == mReturnValueIds.end()) if (mReturnValueVariables.find(functionId.get()) == mReturnValueVariables.end())
{ {
mReturnValueIds[functionId.get()] = new TSymbolUniqueId(mSymbolTable); TType returnValueVariableType(node->getType());
returnValueVariableType.setQualifier(EvqOut);
mReturnValueVariables[functionId.get()] =
new TVariable(mSymbolTable, mReturnValueVariableName, returnValueVariableType,
SymbolType::AngleInternal);
} }
replacement->getSequence()->push_back( replacement->getSequence()->push_back(
CreateReturnValueSymbol(*mReturnValueIds[functionId.get()], node->getType())); new TIntermSymbol(mReturnValueVariables[functionId.get()]));
*replacement->getFunctionSymbolInfo() = *node->getFunctionSymbolInfo(); *replacement->getFunctionSymbolInfo() = *node->getFunctionSymbolInfo();
replacement->setLine(node->getLine()); replacement->setLine(node->getLine());
...@@ -145,11 +146,21 @@ bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TInter ...@@ -145,11 +146,21 @@ bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TInter
TIntermBlock *parentBlock = getParentNode()->getAsBlock(); TIntermBlock *parentBlock = getParentNode()->getAsBlock();
if (parentBlock) if (parentBlock)
{ {
nextTemporaryId(); // replace
// f();
// with
// type s0[size]; f(s0);
TIntermSequence replacements; TIntermSequence replacements;
replacements.push_back(createTempDeclaration(node->getType()));
TIntermSymbol *returnSymbol = createTempSymbol(node->getType()); // type s0[size];
replacements.push_back(CreateReplacementCall(node, returnSymbol)); TIntermDeclaration *returnValueDeclaration = nullptr;
TVariable *returnValue = DeclareTempVariable(mSymbolTable, node->getType(),
EvqTemporary, &returnValueDeclaration);
replacements.push_back(returnValueDeclaration);
// f(s0);
TIntermSymbol *returnValueSymbol = CreateTempSymbolNode(returnValue);
replacements.push_back(CreateReplacementCall(node, returnValueSymbol));
mMultiReplacements.push_back( mMultiReplacements.push_back(
NodeReplaceWithMultipleEntry(parentBlock, node, replacements)); NodeReplaceWithMultipleEntry(parentBlock, node, replacements));
} }
...@@ -169,10 +180,9 @@ bool ArrayReturnValueToOutParameterTraverser::visitBranch(Visit visit, TIntermBr ...@@ -169,10 +180,9 @@ bool ArrayReturnValueToOutParameterTraverser::visitBranch(Visit visit, TIntermBr
ASSERT(expression != nullptr); ASSERT(expression != nullptr);
const TSymbolUniqueId &functionId = const TSymbolUniqueId &functionId =
mFunctionWithArrayReturnValue->getFunctionSymbolInfo()->getId(); mFunctionWithArrayReturnValue->getFunctionSymbolInfo()->getId();
ASSERT(mReturnValueIds.find(functionId.get()) != mReturnValueIds.end()); ASSERT(mReturnValueVariables.find(functionId.get()) != mReturnValueVariables.end());
const TSymbolUniqueId &returnValueId = *mReturnValueIds[functionId.get()];
TIntermSymbol *returnValueSymbol = TIntermSymbol *returnValueSymbol =
CreateReturnValueSymbol(returnValueId, expression->getType()); new TIntermSymbol(mReturnValueVariables[functionId.get()]);
TIntermBinary *replacementAssignment = TIntermBinary *replacementAssignment =
new TIntermBinary(EOpAssign, returnValueSymbol, expression); new TIntermBinary(EOpAssign, returnValueSymbol, expression);
replacementAssignment->setLine(expression->getLine()); replacementAssignment->setLine(expression->getLine());
......
...@@ -32,7 +32,7 @@ class ReplaceVariableTraverser : public TIntermTraverser ...@@ -32,7 +32,7 @@ class ReplaceVariableTraverser : public TIntermTraverser
void visitSymbol(TIntermSymbol *node) override void visitSymbol(TIntermSymbol *node) override
{ {
TName &name = node->getName(); const TName &name = node->getName();
if (name.getString() == mSymbolName) if (name.getString() == mSymbolName)
{ {
queueReplacement(mNewSymbol->deepCopy(), OriginalNode::IS_DROPPED); queueReplacement(mNewSymbol->deepCopy(), OriginalNode::IS_DROPPED);
...@@ -168,9 +168,11 @@ void DeclareAndInitBuiltinsForInstancedMultiview(TIntermBlock *root, ...@@ -168,9 +168,11 @@ void DeclareAndInitBuiltinsForInstancedMultiview(TIntermBlock *root,
ASSERT(shaderType == GL_VERTEX_SHADER || shaderType == GL_FRAGMENT_SHADER); ASSERT(shaderType == GL_VERTEX_SHADER || shaderType == GL_FRAGMENT_SHADER);
TQualifier viewIDQualifier = (shaderType == GL_VERTEX_SHADER) ? EvqFlatOut : EvqFlatIn; TQualifier viewIDQualifier = (shaderType == GL_VERTEX_SHADER) ? EvqFlatOut : EvqFlatIn;
TIntermSymbol *viewIDSymbol = new TIntermSymbol(symbolTable->nextUniqueId(), "ViewID_OVR", const TString *viewIDVariableName = NewPoolTString("ViewID_OVR");
TType(EbtUInt, EbpHigh, viewIDQualifier)); const TVariable *viewIDVariable =
viewIDSymbol->setInternal(true); new TVariable(symbolTable, viewIDVariableName, TType(EbtUInt, EbpHigh, viewIDQualifier),
SymbolType::AngleInternal);
TIntermSymbol *viewIDSymbol = new TIntermSymbol(viewIDVariable);
DeclareGlobalVariable(root, viewIDSymbol); DeclareGlobalVariable(root, viewIDSymbol);
ReplaceSymbol(root, "gl_ViewID_OVR", viewIDSymbol); ReplaceSymbol(root, "gl_ViewID_OVR", viewIDSymbol);
...@@ -178,9 +180,11 @@ void DeclareAndInitBuiltinsForInstancedMultiview(TIntermBlock *root, ...@@ -178,9 +180,11 @@ void DeclareAndInitBuiltinsForInstancedMultiview(TIntermBlock *root,
{ {
// Replacing gl_InstanceID with InstanceID should happen before adding the initializers of // Replacing gl_InstanceID with InstanceID should happen before adding the initializers of
// InstanceID and ViewID. // InstanceID and ViewID.
TIntermSymbol *instanceIDSymbol = new TIntermSymbol( const TString *instanceIDVariableName = NewPoolTString("InstanceID");
symbolTable->nextUniqueId(), "InstanceID", TType(EbtInt, EbpHigh, EvqGlobal)); const TVariable *instanceIDVariable =
instanceIDSymbol->setInternal(true); new TVariable(symbolTable, instanceIDVariableName, TType(EbtInt, EbpHigh, EvqGlobal),
SymbolType::AngleInternal);
TIntermSymbol *instanceIDSymbol = new TIntermSymbol(instanceIDVariable);
DeclareGlobalVariable(root, instanceIDSymbol); DeclareGlobalVariable(root, instanceIDSymbol);
ReplaceSymbol(root, "gl_InstanceID", instanceIDSymbol); ReplaceSymbol(root, "gl_InstanceID", instanceIDSymbol);
...@@ -197,10 +201,13 @@ void DeclareAndInitBuiltinsForInstancedMultiview(TIntermBlock *root, ...@@ -197,10 +201,13 @@ void DeclareAndInitBuiltinsForInstancedMultiview(TIntermBlock *root,
if (selectView) if (selectView)
{ {
// Add a uniform to switch between side-by-side and layered rendering. // Add a uniform to switch between side-by-side and layered rendering.
const TString *multiviewBaseViewLayerIndexVariableName =
NewPoolTString("multiviewBaseViewLayerIndex");
const TVariable *multiviewBaseViewLayerIndexVariable =
new TVariable(symbolTable, multiviewBaseViewLayerIndexVariableName,
TType(EbtInt, EbpHigh, EvqUniform), SymbolType::AngleInternal);
TIntermSymbol *multiviewBaseViewLayerIndexSymbol = TIntermSymbol *multiviewBaseViewLayerIndexSymbol =
new TIntermSymbol(symbolTable->nextUniqueId(), "multiviewBaseViewLayerIndex", new TIntermSymbol(multiviewBaseViewLayerIndexVariable);
TType(EbtInt, EbpHigh, EvqUniform));
multiviewBaseViewLayerIndexSymbol->setInternal(true);
DeclareGlobalVariable(root, multiviewBaseViewLayerIndexSymbol); DeclareGlobalVariable(root, multiviewBaseViewLayerIndexSymbol);
// Setting a value to gl_ViewportIndex or gl_Layer should happen after ViewID_OVR's // Setting a value to gl_ViewportIndex or gl_Layer should happen after ViewID_OVR's
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <cmath> #include <cmath>
#include <cstdlib> #include <cstdlib>
#include "compiler/translator/IntermNode_util.h"
#include "compiler/translator/IntermTraverse.h" #include "compiler/translator/IntermTraverse.h"
namespace sh namespace sh
...@@ -103,20 +104,20 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -103,20 +104,20 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
} }
// Potential problem case detected, apply workaround. // Potential problem case detected, apply workaround.
nextTemporaryId();
TIntermTyped *lhs = sequence->at(0)->getAsTyped(); TIntermTyped *lhs = sequence->at(0)->getAsTyped();
ASSERT(lhs); ASSERT(lhs);
TIntermDeclaration *init = createTempInitDeclaration(lhs); TIntermDeclaration *lhsVariableDeclaration = nullptr;
TIntermTyped *current = createTempSymbol(lhs->getType()); TVariable *lhsVariable =
DeclareTempVariable(mSymbolTable, lhs, EvqTemporary, &lhsVariableDeclaration);
insertStatementInParentBlock(init); insertStatementInParentBlock(lhsVariableDeclaration);
// Create a chain of n-1 multiples. // Create a chain of n-1 multiples.
TIntermTyped *current = CreateTempSymbolNode(lhsVariable);
for (int i = 1; i < n; ++i) for (int i = 1; i < n; ++i)
{ {
TIntermBinary *mul = new TIntermBinary(EOpMul, current, createTempSymbol(lhs->getType())); TIntermBinary *mul = new TIntermBinary(EOpMul, current, CreateTempSymbolNode(lhsVariable));
mul->setLine(node->getLine()); mul->setLine(node->getLine());
current = mul; current = mul;
} }
......
...@@ -102,13 +102,12 @@ void AddArrayZeroInitForLoop(const TIntermTyped *initializedNode, ...@@ -102,13 +102,12 @@ void AddArrayZeroInitForLoop(const TIntermTyped *initializedNode,
TSymbolTable *symbolTable) TSymbolTable *symbolTable)
{ {
ASSERT(initializedNode->isArray()); ASSERT(initializedNode->isArray());
TSymbolUniqueId indexSymbol(symbolTable); TVariable *indexVariable = CreateTempVariable(
symbolTable, TType(EbtInt, highPrecisionSupported ? EbpHigh : EbpMedium, EvqTemporary));
TType indexType(EbtInt, highPrecisionSupported ? EbpHigh : EbpMedium); TIntermSymbol *indexSymbolNode = CreateTempSymbolNode(indexVariable);
TIntermSymbol *indexSymbolNode = CreateTempSymbolNode(indexSymbol, indexType, EvqTemporary);
TIntermDeclaration *indexInit = TIntermDeclaration *indexInit =
CreateTempInitDeclarationNode(indexSymbol, CreateZeroNode(indexType), EvqTemporary); CreateTempInitDeclarationNode(indexVariable, CreateZeroNode(indexVariable->getType()));
TIntermConstantUnion *arraySizeNode = CreateIndexNode(initializedNode->getOutermostArraySize()); TIntermConstantUnion *arraySizeNode = CreateIndexNode(initializedNode->getOutermostArraySize());
TIntermBinary *indexSmallerThanSize = TIntermBinary *indexSmallerThanSize =
new TIntermBinary(EOpLessThan, indexSymbolNode->deepCopy(), arraySizeNode); new TIntermBinary(EOpLessThan, indexSymbolNode->deepCopy(), arraySizeNode);
......
...@@ -271,6 +271,15 @@ bool TIntermAggregateBase::insertChildNodes(TIntermSequence::size_type position, ...@@ -271,6 +271,15 @@ bool TIntermAggregateBase::insertChildNodes(TIntermSequence::size_type position,
return true; return true;
} }
TIntermSymbol::TIntermSymbol(const TVariable *variable)
: TIntermTyped(variable->getType()), mId(variable->uniqueId()), mSymbol(variable->name())
{
if (variable->symbolType() == SymbolType::AngleInternal)
{
mSymbol.setInternal(true);
}
}
TIntermAggregate *TIntermAggregate::CreateFunctionCall(const TFunction &func, TIntermAggregate *TIntermAggregate::CreateFunctionCall(const TFunction &func,
TIntermSequence *arguments) TIntermSequence *arguments)
{ {
......
...@@ -58,6 +58,7 @@ class TIntermBranch; ...@@ -58,6 +58,7 @@ class TIntermBranch;
class TSymbolTable; class TSymbolTable;
class TFunction; class TFunction;
class TVariable;
// Encapsulate an identifier string and track whether it is coming from the original shader code // Encapsulate an identifier string and track whether it is coming from the original shader code
// (not internal) or from ANGLE (internal). Usually internal names shouldn't be decorated or hashed. // (not internal) or from ANGLE (internal). Usually internal names shouldn't be decorated or hashed.
...@@ -255,13 +256,7 @@ class TIntermBranch : public TIntermNode ...@@ -255,13 +256,7 @@ class TIntermBranch : public TIntermNode
class TIntermSymbol : public TIntermTyped class TIntermSymbol : public TIntermTyped
{ {
public: public:
// if symbol is initialized as symbol(sym), the memory comes from the poolallocator of sym. TIntermSymbol(const TVariable *variable);
// If sym comes from per process globalpoolallocator, then it causes increased memory usage
// per compile it is essential to use "symbol = sym" to assign to symbol
TIntermSymbol(const TSymbolUniqueId &id, const TString &symbol, const TType &type)
: TIntermTyped(type), mId(id), mSymbol(symbol)
{
}
TIntermTyped *deepCopy() const override { return new TIntermSymbol(*this); } TIntermTyped *deepCopy() const override { return new TIntermSymbol(*this); }
...@@ -270,9 +265,6 @@ class TIntermSymbol : public TIntermTyped ...@@ -270,9 +265,6 @@ class TIntermSymbol : public TIntermTyped
int getId() const { return mId.get(); } int getId() const { return mId.get(); }
const TString &getSymbol() const { return mSymbol.getString(); } const TString &getSymbol() const { return mSymbol.getString(); }
const TName &getName() const { return mSymbol; } const TName &getName() const { return mSymbol; }
TName &getName() { return mSymbol; }
void setInternal(bool internal) { mSymbol.setInternal(internal); }
void traverse(TIntermTraverser *it) override; void traverse(TIntermTraverser *it) override;
TIntermSymbol *getAsSymbolNode() override { return this; } TIntermSymbol *getAsSymbolNode() override { return this; }
......
...@@ -170,37 +170,80 @@ TIntermConstantUnion *CreateBoolNode(bool value) ...@@ -170,37 +170,80 @@ TIntermConstantUnion *CreateBoolNode(bool value)
return node; return node;
} }
TIntermSymbol *CreateTempSymbolNode(const TSymbolUniqueId &id, TVariable *CreateTempVariable(TSymbolTable *symbolTable, const TType &type)
const TType &type,
TQualifier qualifier)
{ {
TInfoSinkBase symbolNameOut; ASSERT(symbolTable != nullptr);
symbolNameOut << "s" << id.get(); // TODO(oetuaho): Might be useful to sanitize layout qualifier etc. on the type of the created
TString symbolName = symbolNameOut.c_str(); // variable. This might need to be done in other places as well.
return new TVariable(symbolTable, nullptr, type, SymbolType::AngleInternal);
}
TIntermSymbol *node = new TIntermSymbol(id, symbolName, type); TVariable *CreateTempVariable(TSymbolTable *symbolTable, const TType &type, TQualifier qualifier)
node->setInternal(true); {
ASSERT(symbolTable != nullptr);
if (type.getQualifier() == qualifier)
{
return CreateTempVariable(symbolTable, type);
}
TType typeWithQualifier(type);
typeWithQualifier.setQualifier(qualifier);
return CreateTempVariable(symbolTable, typeWithQualifier);
}
ASSERT(qualifier == EvqTemporary || qualifier == EvqConst || qualifier == EvqGlobal); TIntermSymbol *CreateTempSymbolNode(const TVariable *tempVariable)
node->getTypePointer()->setQualifier(qualifier); {
ASSERT(tempVariable->symbolType() == SymbolType::AngleInternal);
ASSERT(tempVariable->getType().getQualifier() == EvqTemporary ||
tempVariable->getType().getQualifier() == EvqConst ||
tempVariable->getType().getQualifier() == EvqGlobal);
return new TIntermSymbol(tempVariable);
}
// TODO(oetuaho): Might be useful to sanitize layout qualifier etc. on the type of the created TIntermDeclaration *CreateTempDeclarationNode(const TVariable *tempVariable)
// symbol. This might need to be done in other places as well. {
return node; TIntermDeclaration *tempDeclaration = new TIntermDeclaration();
tempDeclaration->appendDeclarator(CreateTempSymbolNode(tempVariable));
return tempDeclaration;
} }
TIntermDeclaration *CreateTempInitDeclarationNode(const TSymbolUniqueId &id, TIntermDeclaration *CreateTempInitDeclarationNode(const TVariable *tempVariable,
TIntermTyped *initializer, TIntermTyped *initializer)
TQualifier qualifier)
{ {
ASSERT(initializer != nullptr); ASSERT(initializer != nullptr);
TIntermSymbol *tempSymbol = CreateTempSymbolNode(id, initializer->getType(), qualifier); TIntermSymbol *tempSymbol = CreateTempSymbolNode(tempVariable);
TIntermDeclaration *tempDeclaration = new TIntermDeclaration(); TIntermDeclaration *tempDeclaration = new TIntermDeclaration();
TIntermBinary *tempInit = new TIntermBinary(EOpInitialize, tempSymbol, initializer); TIntermBinary *tempInit = new TIntermBinary(EOpInitialize, tempSymbol, initializer);
tempDeclaration->appendDeclarator(tempInit); tempDeclaration->appendDeclarator(tempInit);
return tempDeclaration; return tempDeclaration;
} }
TIntermBinary *CreateTempAssignmentNode(const TVariable *tempVariable, TIntermTyped *rightNode)
{
ASSERT(rightNode != nullptr);
TIntermSymbol *tempSymbol = CreateTempSymbolNode(tempVariable);
return new TIntermBinary(EOpAssign, tempSymbol, rightNode);
}
TVariable *DeclareTempVariable(TSymbolTable *symbolTable,
const TType &type,
TQualifier qualifier,
TIntermDeclaration **declarationOut)
{
TVariable *variable = CreateTempVariable(symbolTable, type, qualifier);
*declarationOut = CreateTempDeclarationNode(variable);
return variable;
}
TVariable *DeclareTempVariable(TSymbolTable *symbolTable,
TIntermTyped *initializer,
TQualifier qualifier,
TIntermDeclaration **declarationOut)
{
TVariable *variable = CreateTempVariable(symbolTable, initializer->getType(), qualifier);
*declarationOut = CreateTempInitDeclarationNode(variable, initializer);
return variable;
}
TIntermBlock *EnsureBlock(TIntermNode *node) TIntermBlock *EnsureBlock(TIntermNode *node)
{ {
if (node == nullptr) if (node == nullptr)
...@@ -219,7 +262,7 @@ TIntermSymbol *ReferenceGlobalVariable(const TString &name, const TSymbolTable & ...@@ -219,7 +262,7 @@ TIntermSymbol *ReferenceGlobalVariable(const TString &name, const TSymbolTable &
{ {
TVariable *var = reinterpret_cast<TVariable *>(symbolTable.findGlobal(name)); TVariable *var = reinterpret_cast<TVariable *>(symbolTable.findGlobal(name));
ASSERT(var); ASSERT(var);
return new TIntermSymbol(var->uniqueId(), name, var->getType()); return new TIntermSymbol(var);
} }
TIntermSymbol *ReferenceBuiltInVariable(const TString &name, TIntermSymbol *ReferenceBuiltInVariable(const TString &name,
...@@ -229,7 +272,7 @@ TIntermSymbol *ReferenceBuiltInVariable(const TString &name, ...@@ -229,7 +272,7 @@ TIntermSymbol *ReferenceBuiltInVariable(const TString &name,
const TVariable *var = const TVariable *var =
reinterpret_cast<const TVariable *>(symbolTable.findBuiltIn(name, shaderVersion, true)); reinterpret_cast<const TVariable *>(symbolTable.findBuiltIn(name, shaderVersion, true));
ASSERT(var); ASSERT(var);
return new TIntermSymbol(var->uniqueId(), name, var->getType()); return new TIntermSymbol(var);
} }
TIntermTyped *CreateBuiltInFunctionCallNode(const TString &name, TIntermTyped *CreateBuiltInFunctionCallNode(const TString &name,
......
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
namespace sh namespace sh
{ {
class TSymbolTable;
class TVariable;
TIntermFunctionPrototype *CreateInternalFunctionPrototypeNode(const TType &returnType, TIntermFunctionPrototype *CreateInternalFunctionPrototypeNode(const TType &returnType,
const char *name, const char *name,
const TSymbolUniqueId &functionId); const TSymbolUniqueId &functionId);
...@@ -30,12 +33,23 @@ TIntermTyped *CreateZeroNode(const TType &type); ...@@ -30,12 +33,23 @@ TIntermTyped *CreateZeroNode(const TType &type);
TIntermConstantUnion *CreateIndexNode(int index); TIntermConstantUnion *CreateIndexNode(int index);
TIntermConstantUnion *CreateBoolNode(bool value); TIntermConstantUnion *CreateBoolNode(bool value);
TIntermSymbol *CreateTempSymbolNode(const TSymbolUniqueId &id, TVariable *CreateTempVariable(TSymbolTable *symbolTable, const TType &type);
const TType &type, TVariable *CreateTempVariable(TSymbolTable *symbolTable, const TType &type, TQualifier qualifier);
TQualifier qualifier);
TIntermDeclaration *CreateTempInitDeclarationNode(const TSymbolUniqueId &id, TIntermSymbol *CreateTempSymbolNode(const TVariable *tempVariable);
TIntermTyped *initializer, TIntermDeclaration *CreateTempDeclarationNode(const TVariable *tempVariable);
TQualifier qualifier); TIntermDeclaration *CreateTempInitDeclarationNode(const TVariable *tempVariable,
TIntermTyped *initializer);
TIntermBinary *CreateTempAssignmentNode(const TVariable *tempVariable, TIntermTyped *rightNode);
TVariable *DeclareTempVariable(TSymbolTable *symbolTable,
const TType &type,
TQualifier qualifier,
TIntermDeclaration **declarationOut);
TVariable *DeclareTempVariable(TSymbolTable *symbolTable,
TIntermTyped *initializer,
TQualifier qualifier,
TIntermDeclaration **declarationOut);
// If the input node is nullptr, return nullptr. // If the input node is nullptr, return nullptr.
// If the input node is a block node, return it. // If the input node is a block node, return it.
......
...@@ -113,8 +113,7 @@ TIntermTraverser::TIntermTraverser(bool preVisit, ...@@ -113,8 +113,7 @@ TIntermTraverser::TIntermTraverser(bool preVisit,
mDepth(-1), mDepth(-1),
mMaxDepth(0), mMaxDepth(0),
mInGlobalScope(true), mInGlobalScope(true),
mSymbolTable(symbolTable), mSymbolTable(symbolTable)
mTemporaryId(nullptr)
{ {
} }
...@@ -177,58 +176,6 @@ void TIntermTraverser::insertStatementInParentBlock(TIntermNode *statement) ...@@ -177,58 +176,6 @@ void TIntermTraverser::insertStatementInParentBlock(TIntermNode *statement)
insertStatementsInParentBlock(insertions); insertStatementsInParentBlock(insertions);
} }
TIntermSymbol *TIntermTraverser::createTempSymbol(const TType &type, TQualifier qualifier)
{
ASSERT(mTemporaryId != nullptr);
// nextTemporaryId() needs to be called when the code wants to start using another temporary
// symbol.
return CreateTempSymbolNode(*mTemporaryId, type, qualifier);
}
TIntermSymbol *TIntermTraverser::createTempSymbol(const TType &type)
{
return createTempSymbol(type, EvqTemporary);
}
TIntermDeclaration *TIntermTraverser::createTempDeclaration(const TType &type)
{
ASSERT(mTemporaryId != nullptr);
TIntermDeclaration *tempDeclaration = new TIntermDeclaration();
tempDeclaration->appendDeclarator(CreateTempSymbolNode(*mTemporaryId, type, EvqTemporary));
return tempDeclaration;
}
TIntermDeclaration *TIntermTraverser::createTempInitDeclaration(TIntermTyped *initializer,
TQualifier qualifier)
{
ASSERT(mTemporaryId != nullptr);
return CreateTempInitDeclarationNode(*mTemporaryId, initializer, qualifier);
}
TIntermDeclaration *TIntermTraverser::createTempInitDeclaration(TIntermTyped *initializer)
{
return createTempInitDeclaration(initializer, EvqTemporary);
}
TIntermBinary *TIntermTraverser::createTempAssignment(TIntermTyped *rightNode)
{
ASSERT(rightNode != nullptr);
TIntermSymbol *tempSymbol = createTempSymbol(rightNode->getType());
TIntermBinary *assignment = new TIntermBinary(EOpAssign, tempSymbol, rightNode);
return assignment;
}
void TIntermTraverser::nextTemporaryId()
{
ASSERT(mSymbolTable);
if (!mTemporaryId)
{
mTemporaryId = new TSymbolUniqueId(mSymbolTable);
return;
}
*mTemporaryId = TSymbolUniqueId(mSymbolTable);
}
void TLValueTrackingTraverser::addToFunctionMap(const TSymbolUniqueId &id, void TLValueTrackingTraverser::addToFunctionMap(const TSymbolUniqueId &id,
TIntermSequence *paramSequence) TIntermSequence *paramSequence)
{ {
......
...@@ -179,22 +179,6 @@ class TIntermTraverser : angle::NonCopyable ...@@ -179,22 +179,6 @@ class TIntermTraverser : angle::NonCopyable
// Helper to insert a single statement. // Helper to insert a single statement.
void insertStatementInParentBlock(TIntermNode *statement); void insertStatementInParentBlock(TIntermNode *statement);
// Helper to create a temporary symbol node with the given qualifier.
TIntermSymbol *createTempSymbol(const TType &type, TQualifier qualifier);
// Helper to create a temporary symbol node.
TIntermSymbol *createTempSymbol(const TType &type);
// Create a node that declares but doesn't initialize a temporary symbol.
TIntermDeclaration *createTempDeclaration(const TType &type);
// Create a node that initializes the current temporary symbol with initializer. The symbol will
// have the given qualifier.
TIntermDeclaration *createTempInitDeclaration(TIntermTyped *initializer, TQualifier qualifier);
// Create a node that initializes the current temporary symbol with initializer.
TIntermDeclaration *createTempInitDeclaration(TIntermTyped *initializer);
// Create a node that assigns rightNode to the current temporary symbol.
TIntermBinary *createTempAssignment(TIntermTyped *rightNode);
// Increment temporary symbol index.
void nextTemporaryId();
enum class OriginalNode enum class OriginalNode
{ {
BECOMES_CHILD, BECOMES_CHILD,
...@@ -290,8 +274,6 @@ class TIntermTraverser : angle::NonCopyable ...@@ -290,8 +274,6 @@ class TIntermTraverser : angle::NonCopyable
// All the code blocks from the root to the current node's parent during traversal. // All the code blocks from the root to the current node's parent during traversal.
std::vector<ParentBlock> mParentBlockStack; std::vector<ParentBlock> mParentBlockStack;
TSymbolUniqueId *mTemporaryId;
}; };
// Traverser parent class that tracks where a node is a destination of a write operation and so is // Traverser parent class that tracks where a node is a destination of a write operation and so is
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "compiler/translator/RecordConstantPrecision.h" #include "compiler/translator/RecordConstantPrecision.h"
#include "compiler/translator/InfoSink.h" #include "compiler/translator/InfoSink.h"
#include "compiler/translator/IntermNode_util.h"
#include "compiler/translator/IntermTraverse.h" #include "compiler/translator/IntermTraverse.h"
namespace sh namespace sh
...@@ -136,12 +137,10 @@ void RecordConstantPrecisionTraverser::visitConstantUnion(TIntermConstantUnion * ...@@ -136,12 +137,10 @@ void RecordConstantPrecisionTraverser::visitConstantUnion(TIntermConstantUnion *
// Make the constant a precision-qualified named variable to make sure it affects the precision // Make the constant a precision-qualified named variable to make sure it affects the precision
// of the consuming expression. // of the consuming expression.
nextTemporaryId(); TIntermDeclaration *variableDeclaration = nullptr;
TVariable *variable = DeclareTempVariable(mSymbolTable, node, EvqConst, &variableDeclaration);
TIntermSequence insertions; insertStatementInParentBlock(variableDeclaration);
insertions.push_back(createTempInitDeclaration(node, EvqConst)); queueReplacement(CreateTempSymbolNode(variable), OriginalNode::IS_DROPPED);
insertStatementsInParentBlock(insertions);
queueReplacement(createTempSymbol(node->getType()), OriginalNode::IS_DROPPED);
mFoundHigherPrecisionConstant = true; mFoundHigherPrecisionConstant = true;
} }
......
...@@ -58,29 +58,30 @@ std::string GetIndexFunctionName(const TType &type, bool write) ...@@ -58,29 +58,30 @@ std::string GetIndexFunctionName(const TType &type, bool write)
return nameSink.str(); return nameSink.str();
} }
TIntermSymbol *CreateBaseSymbol(const TType &type, TQualifier qualifier, TSymbolTable *symbolTable) TIntermSymbol *CreateBaseSymbol(const TType &type, TSymbolTable *symbolTable)
{ {
TIntermSymbol *symbol = new TIntermSymbol(symbolTable->nextUniqueId(), "base", type); TString *baseString = NewPoolTString("base");
symbol->setInternal(true); TVariable *baseVariable =
symbol->getTypePointer()->setQualifier(qualifier); new TVariable(symbolTable, baseString, type, SymbolType::AngleInternal);
return symbol; return new TIntermSymbol(baseVariable);
} }
TIntermSymbol *CreateIndexSymbol(TSymbolTable *symbolTable) TIntermSymbol *CreateIndexSymbol(TSymbolTable *symbolTable)
{ {
TIntermSymbol *symbol = TString *indexString = NewPoolTString("index");
new TIntermSymbol(symbolTable->nextUniqueId(), "index", TType(EbtInt, EbpHigh)); TVariable *indexVariable = new TVariable(
symbol->setInternal(true); symbolTable, indexString, TType(EbtInt, EbpHigh, EvqIn), SymbolType::AngleInternal);
symbol->getTypePointer()->setQualifier(EvqIn); return new TIntermSymbol(indexVariable);
return symbol;
} }
TIntermSymbol *CreateValueSymbol(const TType &type, TSymbolTable *symbolTable) TIntermSymbol *CreateValueSymbol(const TType &type, TSymbolTable *symbolTable)
{ {
TIntermSymbol *symbol = new TIntermSymbol(symbolTable->nextUniqueId(), "value", type); TString *valueString = NewPoolTString("value");
symbol->setInternal(true); TType valueType(type);
symbol->getTypePointer()->setQualifier(EvqIn); valueType.setQualifier(EvqIn);
return symbol; TVariable *valueVariable =
new TVariable(symbolTable, valueString, valueType, SymbolType::AngleInternal);
return new TIntermSymbol(valueVariable);
} }
TIntermConstantUnion *CreateIntConstantNode(int i) TIntermConstantUnion *CreateIntConstantNode(int i)
...@@ -157,17 +158,12 @@ TType GetFieldType(const TType &indexedType) ...@@ -157,17 +158,12 @@ TType GetFieldType(const TType &indexedType)
// base[1] = value; // base[1] = value;
// } // }
// Note that else is not used in above functions to avoid the RewriteElseBlocks transformation. // Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
TIntermFunctionDefinition *GetIndexFunctionDefinition(TType type, TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type,
bool write, bool write,
const TSymbolUniqueId &functionId, const TSymbolUniqueId &functionId,
TSymbolTable *symbolTable) TSymbolTable *symbolTable)
{ {
ASSERT(!type.isArray()); ASSERT(!type.isArray());
// Conservatively use highp here, even if the indexed type is not highp. That way the code can't
// end up using mediump version of an indexing function for a highp value, if both mediump and
// highp values are being indexed in the shader. For HLSL precision doesn't matter, but in
// principle this code could be used with multiple backends.
type.setPrecision(EbpHigh);
TType fieldType = GetFieldType(type); TType fieldType = GetFieldType(type);
int numCases = 0; int numCases = 0;
...@@ -190,10 +186,17 @@ TIntermFunctionDefinition *GetIndexFunctionDefinition(TType type, ...@@ -190,10 +186,17 @@ TIntermFunctionDefinition *GetIndexFunctionDefinition(TType type,
TIntermFunctionPrototype *prototypeNode = TIntermFunctionPrototype *prototypeNode =
CreateInternalFunctionPrototypeNode(returnType, functionName.c_str(), functionId); CreateInternalFunctionPrototypeNode(returnType, functionName.c_str(), functionId);
TQualifier baseQualifier = EvqInOut; TType baseType(type);
// Conservatively use highp here, even if the indexed type is not highp. That way the code can't
// end up using mediump version of an indexing function for a highp value, if both mediump and
// highp values are being indexed in the shader. For HLSL precision doesn't matter, but in
// principle this code could be used with multiple backends.
baseType.setPrecision(EbpHigh);
baseType.setQualifier(EvqInOut);
if (!write) if (!write)
baseQualifier = EvqIn; baseType.setQualifier(EvqIn);
TIntermSymbol *baseParam = CreateBaseSymbol(type, baseQualifier, symbolTable);
TIntermSymbol *baseParam = CreateBaseSymbol(baseType, symbolTable);
prototypeNode->getSequence()->push_back(baseParam); prototypeNode->getSequence()->push_back(baseParam);
TIntermSymbol *indexParam = CreateIndexSymbol(symbolTable); TIntermSymbol *indexParam = CreateIndexSymbol(symbolTable);
prototypeNode->getSequence()->push_back(indexParam); prototypeNode->getSequence()->push_back(indexParam);
...@@ -361,16 +364,16 @@ TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node, ...@@ -361,16 +364,16 @@ TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
} }
TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node, TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
TIntermTyped *index, TVariable *index,
TIntermTyped *writtenValue, TVariable *writtenValue,
const TSymbolUniqueId &functionId) const TSymbolUniqueId &functionId)
{ {
ASSERT(node->getOp() == EOpIndexIndirect); ASSERT(node->getOp() == EOpIndexIndirect);
TIntermSequence *arguments = new TIntermSequence(); TIntermSequence *arguments = new TIntermSequence();
// Deep copy the child nodes so that two pointers to the same node don't end up in the tree. // Deep copy the child nodes so that two pointers to the same node don't end up in the tree.
arguments->push_back(node->getLeft()->deepCopy()); arguments->push_back(node->getLeft()->deepCopy());
arguments->push_back(index->deepCopy()); arguments->push_back(CreateTempSymbolNode(index));
arguments->push_back(writtenValue); arguments->push_back(CreateTempSymbolNode(writtenValue));
std::string functionName = GetIndexFunctionName(node->getLeft()->getType(), true); std::string functionName = GetIndexFunctionName(node->getLeft()->getType(), true);
TIntermAggregate *indexedWriteCall = TIntermAggregate *indexedWriteCall =
...@@ -394,15 +397,14 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod ...@@ -394,15 +397,14 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod
// to this: // to this:
// int s0 = index_expr; v_expr[s0]; // int s0 = index_expr; v_expr[s0];
// Now v_expr[s0] can be safely executed several times without unintended side effects. // Now v_expr[s0] can be safely executed several times without unintended side effects.
nextTemporaryId(); TIntermDeclaration *indexVariableDeclaration = nullptr;
TVariable *indexVariable = DeclareTempVariable(mSymbolTable, node->getRight(),
// Init the temp variable holding the index EvqTemporary, &indexVariableDeclaration);
TIntermDeclaration *initIndex = createTempInitDeclaration(node->getRight()); insertStatementInParentBlock(indexVariableDeclaration);
insertStatementInParentBlock(initIndex);
mUsedTreeInsertion = true; mUsedTreeInsertion = true;
// Replace the index with the temp variable // Replace the index with the temp variable
TIntermSymbol *tempIndex = createTempSymbol(node->getRight()->getType()); TIntermSymbol *tempIndex = CreateTempSymbolNode(indexVariable);
queueReplacementWithParent(node, node->getRight(), tempIndex, OriginalNode::IS_DROPPED); queueReplacementWithParent(node, node->getRight(), tempIndex, OriginalNode::IS_DROPPED);
} }
else if (IntermNodePatternMatcher::IsDynamicIndexingOfVectorOrMatrix(node)) else if (IntermNodePatternMatcher::IsDynamicIndexingOfVectorOrMatrix(node))
...@@ -473,34 +475,34 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod ...@@ -473,34 +475,34 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod
{ {
indexedWriteFunctionId = mWrittenVecAndMatrixTypes[type]; indexedWriteFunctionId = mWrittenVecAndMatrixTypes[type];
} }
TType fieldType = GetFieldType(type);
TIntermSequence insertionsBefore; TIntermSequence insertionsBefore;
TIntermSequence insertionsAfter; TIntermSequence insertionsAfter;
// Store the index in a temporary signed int variable. // Store the index in a temporary signed int variable.
nextTemporaryId(); // s0 = index_expr;
TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight()); TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight());
TIntermDeclaration *initIndex = createTempInitDeclaration(indexInitializer); TIntermDeclaration *indexVariableDeclaration = nullptr;
initIndex->setLine(node->getLine()); TVariable *indexVariable = DeclareTempVariable(
insertionsBefore.push_back(initIndex); mSymbolTable, indexInitializer, EvqTemporary, &indexVariableDeclaration);
insertionsBefore.push_back(indexVariableDeclaration);
// Create a node for referring to the index after the nextTemporaryId() call
// below.
TIntermSymbol *tempIndex = createTempSymbol(indexInitializer->getType());
TIntermAggregate *indexingCall = // s1 = dyn_index(v_expr, s0);
CreateIndexFunctionCall(node, tempIndex, *indexingFunctionId); TIntermAggregate *indexingCall = CreateIndexFunctionCall(
node, CreateTempSymbolNode(indexVariable), *indexingFunctionId);
nextTemporaryId(); // From now on, creating temporary symbols that refer to the TIntermDeclaration *fieldVariableDeclaration = nullptr;
// field value. TVariable *fieldVariable = DeclareTempVariable(
insertionsBefore.push_back(createTempInitDeclaration(indexingCall)); mSymbolTable, indexingCall, EvqTemporary, &fieldVariableDeclaration);
insertionsBefore.push_back(fieldVariableDeclaration);
// dyn_index_write(v_expr, s0, s1);
TIntermAggregate *indexedWriteCall = CreateIndexedWriteFunctionCall( TIntermAggregate *indexedWriteCall = CreateIndexedWriteFunctionCall(
node, tempIndex, createTempSymbol(fieldType), *indexedWriteFunctionId); node, indexVariable, fieldVariable, *indexedWriteFunctionId);
insertionsAfter.push_back(indexedWriteCall); insertionsAfter.push_back(indexedWriteCall);
insertStatementsInParentBlock(insertionsBefore, insertionsAfter); insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
queueReplacement(createTempSymbol(fieldType), OriginalNode::IS_DROPPED);
// replace the node with s1
queueReplacement(CreateTempSymbolNode(fieldVariable), OriginalNode::IS_DROPPED);
mUsedTreeInsertion = true; mUsedTreeInsertion = true;
} }
else else
......
...@@ -192,9 +192,9 @@ void RemoveUnreferencedVariablesTraverser::removeVariableDeclaration(TIntermDecl ...@@ -192,9 +192,9 @@ void RemoveUnreferencedVariablesTraverser::removeVariableDeclaration(TIntermDecl
// Already an empty declaration - nothing to do. // Already an empty declaration - nothing to do.
return; return;
} }
queueReplacementWithParent(node, declarator, TVariable *emptyVariable = new TVariable(mSymbolTable, NewPoolTString(""),
new TIntermSymbol(mSymbolTable->getEmptySymbolId(), declarator->getType(), SymbolType::Empty);
TString(""), declarator->getType()), queueReplacementWithParent(node, declarator, new TIntermSymbol(emptyVariable),
OriginalNode::IS_DROPPED); OriginalNode::IS_DROPPED);
return; return;
} }
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "compiler/translator/RewriteDoWhile.h" #include "compiler/translator/RewriteDoWhile.h"
#include "compiler/translator/IntermNode_util.h"
#include "compiler/translator/IntermTraverse.h" #include "compiler/translator/IntermTraverse.h"
namespace sh namespace sh
...@@ -70,29 +71,16 @@ class DoWhileRewriter : public TIntermTraverser ...@@ -70,29 +71,16 @@ class DoWhileRewriter : public TIntermTraverser
} }
// Found a loop to change. // Found a loop to change.
nextTemporaryId(); TType boolType(EbtBool);
TVariable *conditionVariable = CreateTempVariable(mSymbolTable, boolType);
TType boolType = TType(EbtBool);
// bool temp = false; // bool temp = false;
TIntermDeclaration *tempDeclaration = nullptr; TIntermDeclaration *tempDeclaration =
{ CreateTempInitDeclarationNode(conditionVariable, CreateBoolNode(false));
TConstantUnion *falseConstant = new TConstantUnion();
falseConstant->setBConst(false);
TIntermTyped *falseValue = new TIntermConstantUnion(falseConstant, boolType);
tempDeclaration = createTempInitDeclaration(falseValue);
}
// temp = true; // temp = true;
TIntermBinary *assignTrue = nullptr; TIntermBinary *assignTrue =
{ CreateTempAssignmentNode(conditionVariable, CreateBoolNode(true));
TConstantUnion *trueConstant = new TConstantUnion();
trueConstant->setBConst(true);
TIntermTyped *trueValue = new TIntermConstantUnion(trueConstant, boolType);
assignTrue = createTempAssignment(trueValue);
}
// if (temp) { // if (temp) {
// if (!CONDITION) { // if (!CONDITION) {
...@@ -114,17 +102,14 @@ class DoWhileRewriter : public TIntermTraverser ...@@ -114,17 +102,14 @@ class DoWhileRewriter : public TIntermTraverser
TIntermBlock *innerIfBlock = new TIntermBlock(); TIntermBlock *innerIfBlock = new TIntermBlock();
innerIfBlock->getSequence()->push_back(innerIf); innerIfBlock->getSequence()->push_back(innerIf);
breakIf = new TIntermIfElse(createTempSymbol(boolType), innerIfBlock, nullptr); breakIf = new TIntermIfElse(CreateTempSymbolNode(conditionVariable), innerIfBlock,
nullptr);
} }
// Assemble the replacement loops, reusing the do-while loop's body and inserting our // Assemble the replacement loops, reusing the do-while loop's body and inserting our
// statements at the front. // statements at the front.
TIntermLoop *newLoop = nullptr; TIntermLoop *newLoop = nullptr;
{ {
TConstantUnion *trueConstant = new TConstantUnion();
trueConstant->setBConst(true);
TIntermTyped *trueValue = new TIntermConstantUnion(trueConstant, boolType);
TIntermBlock *body = loop->getBody(); TIntermBlock *body = loop->getBody();
if (body == nullptr) if (body == nullptr)
{ {
...@@ -134,7 +119,7 @@ class DoWhileRewriter : public TIntermTraverser ...@@ -134,7 +119,7 @@ class DoWhileRewriter : public TIntermTraverser
sequence->insert(sequence->begin(), assignTrue); sequence->insert(sequence->begin(), assignTrue);
sequence->insert(sequence->begin(), breakIf); sequence->insert(sequence->begin(), breakIf);
newLoop = new TIntermLoop(ELoopWhile, nullptr, trueValue, nullptr, body); newLoop = new TIntermLoop(ELoopWhile, nullptr, CreateBoolNode(true), nullptr, body);
} }
TIntermSequence replacement; TIntermSequence replacement;
......
...@@ -69,9 +69,9 @@ TIntermNode *ElseBlockRewriter::rewriteIfElse(TIntermIfElse *ifElse) ...@@ -69,9 +69,9 @@ TIntermNode *ElseBlockRewriter::rewriteIfElse(TIntermIfElse *ifElse)
{ {
ASSERT(ifElse != nullptr); ASSERT(ifElse != nullptr);
nextTemporaryId(); TIntermDeclaration *storeCondition = nullptr;
TVariable *conditionVariable =
TIntermDeclaration *storeCondition = createTempInitDeclaration(ifElse->getCondition()); DeclareTempVariable(mSymbolTable, ifElse->getCondition(), EvqTemporary, &storeCondition);
TIntermBlock *falseBlock = nullptr; TIntermBlock *falseBlock = nullptr;
...@@ -91,14 +91,14 @@ TIntermNode *ElseBlockRewriter::rewriteIfElse(TIntermIfElse *ifElse) ...@@ -91,14 +91,14 @@ TIntermNode *ElseBlockRewriter::rewriteIfElse(TIntermIfElse *ifElse)
negatedElse->appendStatement(returnNode); negatedElse->appendStatement(returnNode);
} }
TIntermSymbol *conditionSymbolElse = createTempSymbol(boolType); TIntermSymbol *conditionSymbolElse = CreateTempSymbolNode(conditionVariable);
TIntermUnary *negatedCondition = new TIntermUnary(EOpLogicalNot, conditionSymbolElse); TIntermUnary *negatedCondition = new TIntermUnary(EOpLogicalNot, conditionSymbolElse);
TIntermIfElse *falseIfElse = TIntermIfElse *falseIfElse =
new TIntermIfElse(negatedCondition, ifElse->getFalseBlock(), negatedElse); new TIntermIfElse(negatedCondition, ifElse->getFalseBlock(), negatedElse);
falseBlock = EnsureBlock(falseIfElse); falseBlock = EnsureBlock(falseIfElse);
} }
TIntermSymbol *conditionSymbolSel = createTempSymbol(boolType); TIntermSymbol *conditionSymbolSel = CreateTempSymbolNode(conditionVariable);
TIntermIfElse *newIfElse = TIntermIfElse *newIfElse =
new TIntermIfElse(conditionSymbolSel, ifElse->getTrueBlock(), falseBlock); new TIntermIfElse(conditionSymbolSel, ifElse->getTrueBlock(), falseBlock);
......
...@@ -66,7 +66,7 @@ class ScalarizeArgsTraverser : public TIntermTraverser ...@@ -66,7 +66,7 @@ class ScalarizeArgsTraverser : public TIntermTraverser
// vec4 v(1, s0[0][0], s0[0][1], s0[0][2]); // vec4 v(1, s0[0][0], s0[0][1], s0[0][2]);
// This function is to create nodes for "mat4 s0 = m;" and insert it to the code sequence. This // This function is to create nodes for "mat4 s0 = m;" and insert it to the code sequence. This
// way the possible side effects of the constructor argument will only be evaluated once. // way the possible side effects of the constructor argument will only be evaluated once.
void createTempVariable(TIntermTyped *original); TVariable *createTempVariable(TIntermTyped *original);
std::vector<TIntermSequence> mBlockStack; std::vector<TIntermSequence> mBlockStack;
...@@ -129,10 +129,10 @@ void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate, ...@@ -129,10 +129,10 @@ void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate,
ASSERT(size > 0); ASSERT(size > 0);
TIntermTyped *originalArg = originalArgNode->getAsTyped(); TIntermTyped *originalArg = originalArgNode->getAsTyped();
ASSERT(originalArg); ASSERT(originalArg);
createTempVariable(originalArg); TVariable *argVariable = createTempVariable(originalArg);
if (originalArg->isScalar()) if (originalArg->isScalar())
{ {
sequence->push_back(createTempSymbol(originalArg->getType())); sequence->push_back(CreateTempSymbolNode(argVariable));
size--; size--;
} }
else if (originalArg->isVector()) else if (originalArg->isVector())
...@@ -143,14 +143,14 @@ void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate, ...@@ -143,14 +143,14 @@ void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate,
size -= repeat; size -= repeat;
for (int index = 0; index < repeat; ++index) for (int index = 0; index < repeat; ++index)
{ {
TIntermSymbol *symbolNode = createTempSymbol(originalArg->getType()); TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable);
TIntermBinary *newNode = ConstructVectorIndexBinaryNode(symbolNode, index); TIntermBinary *newNode = ConstructVectorIndexBinaryNode(symbolNode, index);
sequence->push_back(newNode); sequence->push_back(newNode);
} }
} }
else else
{ {
TIntermSymbol *symbolNode = createTempSymbol(originalArg->getType()); TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable);
sequence->push_back(symbolNode); sequence->push_back(symbolNode);
size -= originalArg->getNominalSize(); size -= originalArg->getNominalSize();
} }
...@@ -165,7 +165,7 @@ void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate, ...@@ -165,7 +165,7 @@ void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate,
size -= repeat; size -= repeat;
while (repeat > 0) while (repeat > 0)
{ {
TIntermSymbol *symbolNode = createTempSymbol(originalArg->getType()); TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable);
TIntermBinary *newNode = TIntermBinary *newNode =
ConstructMatrixIndexBinaryNode(symbolNode, colIndex, rowIndex); ConstructMatrixIndexBinaryNode(symbolNode, colIndex, rowIndex);
sequence->push_back(newNode); sequence->push_back(newNode);
...@@ -180,7 +180,7 @@ void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate, ...@@ -180,7 +180,7 @@ void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate,
} }
else else
{ {
TIntermSymbol *symbolNode = createTempSymbol(originalArg->getType()); TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable);
sequence->push_back(symbolNode); sequence->push_back(symbolNode);
size -= originalArg->getCols() * originalArg->getRows(); size -= originalArg->getCols() * originalArg->getRows();
} }
...@@ -188,28 +188,29 @@ void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate, ...@@ -188,28 +188,29 @@ void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate,
} }
} }
void ScalarizeArgsTraverser::createTempVariable(TIntermTyped *original) TVariable *ScalarizeArgsTraverser::createTempVariable(TIntermTyped *original)
{ {
ASSERT(original); ASSERT(original);
nextTemporaryId();
TIntermDeclaration *decl = createTempInitDeclaration(original);
TType type = original->getType(); TType type(original->getType());
type.setQualifier(EvqTemporary);
if (mShaderType == GL_FRAGMENT_SHADER && type.getBasicType() == EbtFloat && if (mShaderType == GL_FRAGMENT_SHADER && type.getBasicType() == EbtFloat &&
type.getPrecision() == EbpUndefined) type.getPrecision() == EbpUndefined)
{ {
// We use the highest available precision for the temporary variable // We use the highest available precision for the temporary variable
// to avoid computing the actual precision using the rules defined // to avoid computing the actual precision using the rules defined
// in GLSL ES 1.0 Section 4.5.2. // in GLSL ES 1.0 Section 4.5.2.
TIntermBinary *init = decl->getSequence()->at(0)->getAsBinaryNode(); type.setPrecision(mFragmentPrecisionHigh ? EbpHigh : EbpMedium);
init->getTypePointer()->setPrecision(mFragmentPrecisionHigh ? EbpHigh : EbpMedium);
init->getLeft()->getTypePointer()->setPrecision(mFragmentPrecisionHigh ? EbpHigh
: EbpMedium);
} }
TVariable *variable = CreateTempVariable(mSymbolTable, type);
ASSERT(mBlockStack.size() > 0); ASSERT(mBlockStack.size() > 0);
TIntermSequence &sequence = mBlockStack.back(); TIntermSequence &sequence = mBlockStack.back();
sequence.push_back(decl); TIntermDeclaration *declaration = CreateTempInitDeclarationNode(variable, original);
sequence.push_back(declaration);
return variable;
} }
} // namespace anonymous } // namespace anonymous
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "compiler/translator/SeparateExpressionsReturningArrays.h" #include "compiler/translator/SeparateExpressionsReturningArrays.h"
#include "compiler/translator/IntermNodePatternMatcher.h" #include "compiler/translator/IntermNodePatternMatcher.h"
#include "compiler/translator/IntermNode_util.h"
#include "compiler/translator/IntermTraverse.h" #include "compiler/translator/IntermTraverse.h"
namespace sh namespace sh
...@@ -73,11 +74,13 @@ bool SeparateExpressionsTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -73,11 +74,13 @@ bool SeparateExpressionsTraverser::visitBinary(Visit visit, TIntermBinary *node)
// TODO(oetuaho): In some cases it would be more optimal to not add the temporary node, but just // TODO(oetuaho): In some cases it would be more optimal to not add the temporary node, but just
// use the original target of the assignment. Care must be taken so that this doesn't happen // use the original target of the assignment. Care must be taken so that this doesn't happen
// when the same array symbol is a target of assignment more than once in one expression. // when the same array symbol is a target of assignment more than once in one expression.
nextTemporaryId(); TIntermDeclaration *arrayVariableDeclaration;
insertions.push_back(createTempInitDeclaration(node->getLeft())); TVariable *arrayVariable =
DeclareTempVariable(mSymbolTable, node->getLeft(), EvqTemporary, &arrayVariableDeclaration);
insertions.push_back(arrayVariableDeclaration);
insertStatementsInParentBlock(insertions); insertStatementsInParentBlock(insertions);
queueReplacement(createTempSymbol(node->getType()), OriginalNode::IS_DROPPED); queueReplacement(CreateTempSymbolNode(arrayVariable), OriginalNode::IS_DROPPED);
return false; return false;
} }
...@@ -94,13 +97,12 @@ bool SeparateExpressionsTraverser::visitAggregate(Visit visit, TIntermAggregate ...@@ -94,13 +97,12 @@ bool SeparateExpressionsTraverser::visitAggregate(Visit visit, TIntermAggregate
mFoundArrayExpression = true; mFoundArrayExpression = true;
nextTemporaryId(); TIntermDeclaration *arrayVariableDeclaration;
TVariable *arrayVariable = DeclareTempVariable(mSymbolTable, node->shallowCopy(), EvqTemporary,
&arrayVariableDeclaration);
insertStatementInParentBlock(arrayVariableDeclaration);
TIntermSequence insertions; queueReplacement(CreateTempSymbolNode(arrayVariable), OriginalNode::IS_DROPPED);
insertions.push_back(createTempInitDeclaration(node->shallowCopy()));
insertStatementsInParentBlock(insertions);
queueReplacement(createTempSymbol(node->getType()), OriginalNode::IS_DROPPED);
return false; return false;
} }
......
...@@ -151,7 +151,8 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node) ...@@ -151,7 +151,8 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node)
if (mFoundLoopToChange) if (mFoundLoopToChange)
{ {
nextTemporaryId(); TType boolType(EbtBool, EbpUndefined, EvqTemporary);
TVariable *conditionVariable = CreateTempVariable(mSymbolTable, boolType);
// Replace the loop condition with a boolean variable that's updated on each iteration. // Replace the loop condition with a boolean variable that's updated on each iteration.
TLoopType loopType = node->getType(); TLoopType loopType = node->getType();
...@@ -162,9 +163,9 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node) ...@@ -162,9 +163,9 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node)
// into // into
// bool s0 = expr; // bool s0 = expr;
// while (s0) { { body; } s0 = expr; } // while (s0) { { body; } s0 = expr; }
TIntermSequence tempInitSeq; TIntermDeclaration *tempInitDeclaration =
tempInitSeq.push_back(createTempInitDeclaration(node->getCondition()->deepCopy())); CreateTempInitDeclarationNode(conditionVariable, node->getCondition()->deepCopy());
insertStatementsInParentBlock(tempInitSeq); insertStatementInParentBlock(tempInitDeclaration);
TIntermBlock *newBody = new TIntermBlock(); TIntermBlock *newBody = new TIntermBlock();
if (node->getBody()) if (node->getBody())
...@@ -172,13 +173,13 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node) ...@@ -172,13 +173,13 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node)
newBody->getSequence()->push_back(node->getBody()); newBody->getSequence()->push_back(node->getBody());
} }
newBody->getSequence()->push_back( newBody->getSequence()->push_back(
createTempAssignment(node->getCondition()->deepCopy())); CreateTempAssignmentNode(conditionVariable, node->getCondition()->deepCopy()));
// Can't use queueReplacement to replace old body, since it may have been nullptr. // Can't use queueReplacement to replace old body, since it may have been nullptr.
// It's safe to do the replacements in place here - the new body will still be // It's safe to do the replacements in place here - the new body will still be
// traversed, but that won't create any problems. // traversed, but that won't create any problems.
node->setBody(newBody); node->setBody(newBody);
node->setCondition(createTempSymbol(node->getCondition()->getType())); node->setCondition(CreateTempSymbolNode(conditionVariable));
} }
else if (loopType == ELoopDoWhile) else if (loopType == ELoopDoWhile)
{ {
...@@ -192,9 +193,9 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node) ...@@ -192,9 +193,9 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node)
// { body; } // { body; }
// s0 = expr; // s0 = expr;
// } while (s0); // } while (s0);
TIntermSequence tempInitSeq; TIntermDeclaration *tempInitDeclaration =
tempInitSeq.push_back(createTempInitDeclaration(CreateBoolNode(true))); CreateTempInitDeclarationNode(conditionVariable, CreateBoolNode(true));
insertStatementsInParentBlock(tempInitSeq); insertStatementInParentBlock(tempInitDeclaration);
TIntermBlock *newBody = new TIntermBlock(); TIntermBlock *newBody = new TIntermBlock();
if (node->getBody()) if (node->getBody())
...@@ -202,13 +203,13 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node) ...@@ -202,13 +203,13 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node)
newBody->getSequence()->push_back(node->getBody()); newBody->getSequence()->push_back(node->getBody());
} }
newBody->getSequence()->push_back( newBody->getSequence()->push_back(
createTempAssignment(node->getCondition()->deepCopy())); CreateTempAssignmentNode(conditionVariable, node->getCondition()->deepCopy()));
// Can't use queueReplacement to replace old body, since it may have been nullptr. // Can't use queueReplacement to replace old body, since it may have been nullptr.
// It's safe to do the replacements in place here - the new body will still be // It's safe to do the replacements in place here - the new body will still be
// traversed, but that won't create any problems. // traversed, but that won't create any problems.
node->setBody(newBody); node->setBody(newBody);
node->setCondition(createTempSymbol(node->getCondition()->getType())); node->setCondition(CreateTempSymbolNode(conditionVariable));
} }
else if (loopType == ELoopFor) else if (loopType == ELoopFor)
{ {
...@@ -244,7 +245,8 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node) ...@@ -244,7 +245,8 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node)
{ {
conditionInitializer = CreateBoolNode(true); conditionInitializer = CreateBoolNode(true);
} }
loopScopeSequence->push_back(createTempInitDeclaration(conditionInitializer)); loopScopeSequence->push_back(
CreateTempInitDeclarationNode(conditionVariable, conditionInitializer));
// Insert "{ body; }" in the while loop // Insert "{ body; }" in the while loop
TIntermBlock *whileLoopBody = new TIntermBlock(); TIntermBlock *whileLoopBody = new TIntermBlock();
...@@ -261,13 +263,13 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node) ...@@ -261,13 +263,13 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node)
if (node->getCondition()) if (node->getCondition())
{ {
whileLoopBody->getSequence()->push_back( whileLoopBody->getSequence()->push_back(
createTempAssignment(node->getCondition()->deepCopy())); CreateTempAssignmentNode(conditionVariable, node->getCondition()->deepCopy()));
} }
// Create "while(s0) { whileLoopBody }" // Create "while(s0) { whileLoopBody }"
TIntermLoop *whileLoop = new TIntermLoop( TIntermLoop *whileLoop =
ELoopWhile, nullptr, createTempSymbol(conditionInitializer->getType()), nullptr, new TIntermLoop(ELoopWhile, nullptr, CreateTempSymbolNode(conditionVariable),
whileLoopBody); nullptr, whileLoopBody);
loopScope->getSequence()->push_back(whileLoop); loopScope->getSequence()->push_back(whileLoop);
queueReplacement(loopScope, OriginalNode::IS_DROPPED); queueReplacement(loopScope, OriginalNode::IS_DROPPED);
......
...@@ -39,6 +39,20 @@ TSymbol::TSymbol(TSymbolTable *symbolTable, ...@@ -39,6 +39,20 @@ TSymbol::TSymbol(TSymbolTable *symbolTable,
mExtension(extension) mExtension(extension)
{ {
ASSERT(mSymbolType == SymbolType::BuiltIn || mExtension == TExtension::UNDEFINED); ASSERT(mSymbolType == SymbolType::BuiltIn || mExtension == TExtension::UNDEFINED);
ASSERT(mName != nullptr || mSymbolType == SymbolType::AngleInternal ||
mSymbolType == SymbolType::NotResolved);
}
const TString &TSymbol::name() const
{
if (mName != nullptr)
{
return *mName;
}
ASSERT(mSymbolType == SymbolType::AngleInternal);
TInfoSinkBase symbolNameOut;
symbolNameOut << "s" << mUniqueId.get();
return *NewPoolTString(symbolNameOut.c_str());
} }
TVariable::TVariable(TSymbolTable *symbolTable, TVariable::TVariable(TSymbolTable *symbolTable,
...@@ -332,9 +346,10 @@ constexpr const TType *VectorType(const TType *type, int size) ...@@ -332,9 +346,10 @@ constexpr const TType *VectorType(const TType *type, int size)
} }
} }
TVariable *TSymbolTable::declareVariable(const TString *name, const TType &type) bool TSymbolTable::declareVariable(TVariable *variable)
{ {
return insertVariable(currentLevel(), name, type, SymbolType::UserDefined); ASSERT(variable->symbolType() == SymbolType::UserDefined);
return insertVariable(currentLevel(), variable);
} }
bool TSymbolTable::declareStructType(TStructure *str) bool TSymbolTable::declareStructType(TStructure *str)
...@@ -388,6 +403,12 @@ TVariable *TSymbolTable::insertVariableExt(ESymbolLevel level, ...@@ -388,6 +403,12 @@ TVariable *TSymbolTable::insertVariableExt(ESymbolLevel level,
return nullptr; return nullptr;
} }
bool TSymbolTable::insertVariable(ESymbolLevel level, TVariable *variable)
{
ASSERT(variable);
return insert(level, variable);
}
bool TSymbolTable::insertStructType(ESymbolLevel level, TStructure *str) bool TSymbolTable::insertStructType(ESymbolLevel level, TStructure *str)
{ {
ASSERT(str); ASSERT(str);
......
...@@ -47,6 +47,7 @@ enum class SymbolType ...@@ -47,6 +47,7 @@ enum class SymbolType
{ {
BuiltIn, BuiltIn,
UserDefined, UserDefined,
AngleInternal,
Empty, // Meaning symbol without a name. Empty, // Meaning symbol without a name.
NotResolved NotResolved
}; };
...@@ -66,7 +67,7 @@ class TSymbol : angle::NonCopyable ...@@ -66,7 +67,7 @@ class TSymbol : angle::NonCopyable
// don't delete name, it's from the pool // don't delete name, it's from the pool
} }
const TString &name() const { return *mName; } const TString &name() const;
virtual const TString &getMangledName() const { return name(); } virtual const TString &getMangledName() const { return name(); }
virtual bool isFunction() const { return false; } virtual bool isFunction() const { return false; }
virtual bool isVariable() const { return false; } virtual bool isVariable() const { return false; }
...@@ -89,6 +90,12 @@ class TSymbol : angle::NonCopyable ...@@ -89,6 +90,12 @@ class TSymbol : angle::NonCopyable
class TVariable : public TSymbol class TVariable : public TSymbol
{ {
public: public:
TVariable(TSymbolTable *symbolTable,
const TString *name,
const TType &t,
SymbolType symbolType,
TExtension ext = TExtension::UNDEFINED);
~TVariable() override {} ~TVariable() override {}
bool isVariable() const override { return true; } bool isVariable() const override { return true; }
TType &getType() { return type; } TType &getType() { return type; }
...@@ -100,13 +107,6 @@ class TVariable : public TSymbol ...@@ -100,13 +107,6 @@ class TVariable : public TSymbol
void shareConstPointer(const TConstantUnion *constArray) { unionArray = constArray; } void shareConstPointer(const TConstantUnion *constArray) { unionArray = constArray; }
private: private:
friend class TSymbolTable;
TVariable(TSymbolTable *symbolTable,
const TString *name,
const TType &t,
SymbolType symbolType,
TExtension ext = TExtension::UNDEFINED);
TType type; TType type;
const TConstantUnion *unionArray; const TConstantUnion *unionArray;
}; };
...@@ -330,7 +330,7 @@ const int GLOBAL_LEVEL = 5; ...@@ -330,7 +330,7 @@ const int GLOBAL_LEVEL = 5;
class TSymbolTable : angle::NonCopyable class TSymbolTable : angle::NonCopyable
{ {
public: public:
TSymbolTable() : mUniqueIdCounter(0), mUserDefinedUniqueIdsStart(-1), mEmptySymbolId(this) TSymbolTable() : mUniqueIdCounter(0), mUserDefinedUniqueIdsStart(-1)
{ {
// The symbol table cannot be used until push() is called, but // The symbol table cannot be used until push() is called, but
// the lack of an initial call to push() can be used to detect // the lack of an initial call to push() can be used to detect
...@@ -363,7 +363,7 @@ class TSymbolTable : angle::NonCopyable ...@@ -363,7 +363,7 @@ class TSymbolTable : angle::NonCopyable
// The declare* entry points are used when parsing and declare symbols at the current scope. // The declare* entry points are used when parsing and declare symbols at the current scope.
// They return the created symbol / true in case the declaration was successful, and nullptr / // They return the created symbol / true in case the declaration was successful, and nullptr /
// false if the declaration failed due to redefinition. // false if the declaration failed due to redefinition.
TVariable *declareVariable(const TString *name, const TType &type); bool declareVariable(TVariable *variable);
bool declareStructType(TStructure *str); bool declareStructType(TStructure *str);
bool declareInterfaceBlock(TInterfaceBlock *interfaceBlock); bool declareInterfaceBlock(TInterfaceBlock *interfaceBlock);
...@@ -375,6 +375,7 @@ class TSymbolTable : angle::NonCopyable ...@@ -375,6 +375,7 @@ class TSymbolTable : angle::NonCopyable
TExtension ext, TExtension ext,
const char *name, const char *name,
const TType &type); const TType &type);
bool insertVariable(ESymbolLevel level, TVariable *variable);
bool insertStructType(ESymbolLevel level, TStructure *str); bool insertStructType(ESymbolLevel level, TStructure *str);
bool insertInterfaceBlock(ESymbolLevel level, TInterfaceBlock *interfaceBlock); bool insertInterfaceBlock(ESymbolLevel level, TInterfaceBlock *interfaceBlock);
...@@ -543,11 +544,6 @@ class TSymbolTable : angle::NonCopyable ...@@ -543,11 +544,6 @@ class TSymbolTable : angle::NonCopyable
const TSymbolUniqueId nextUniqueId() { return TSymbolUniqueId(this); } const TSymbolUniqueId nextUniqueId() { return TSymbolUniqueId(this); }
// The empty symbol id is shared between all empty string ("") symbols. They are used in the
// AST for unused function parameters and struct type declarations that don't declare a
// variable, for example.
const TSymbolUniqueId &getEmptySymbolId() { return mEmptySymbolId; }
// Checks whether there is a built-in accessible by a shader with the specified version. // Checks whether there is a built-in accessible by a shader with the specified version.
bool hasUnmangledBuiltInForShaderVersion(const char *name, int shaderVersion); bool hasUnmangledBuiltInForShaderVersion(const char *name, int shaderVersion);
...@@ -587,8 +583,6 @@ class TSymbolTable : angle::NonCopyable ...@@ -587,8 +583,6 @@ class TSymbolTable : angle::NonCopyable
// TODO(oetuaho): Make this a compile-time constant once the symbol table is initialized at // TODO(oetuaho): Make this a compile-time constant once the symbol table is initialized at
// compile time. http://anglebug.com/1432 // compile time. http://anglebug.com/1432
int mUserDefinedUniqueIdsStart; int mUserDefinedUniqueIdsStart;
const TSymbolUniqueId mEmptySymbolId;
}; };
} // namespace sh } // namespace sh
......
...@@ -803,7 +803,9 @@ void TType::createSamplerSymbols(const TString &namePrefix, ...@@ -803,7 +803,9 @@ void TType::createSamplerSymbols(const TString &namePrefix,
} }
ASSERT(IsSampler(type)); ASSERT(IsSampler(type));
TIntermSymbol *symbol = new TIntermSymbol(symbolTable->nextUniqueId(), namePrefix, *this); TVariable *variable = new TVariable(symbolTable, NewPoolTString(namePrefix.c_str()), *this,
SymbolType::AngleInternal);
TIntermSymbol *symbol = new TIntermSymbol(variable);
outputSymbols->push_back(symbol); outputSymbols->push_back(symbol);
if (outputSymbolsToAPINames) if (outputSymbolsToAPINames)
{ {
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "compiler/translator/UnfoldShortCircuitToIf.h" #include "compiler/translator/UnfoldShortCircuitToIf.h"
#include "compiler/translator/IntermNodePatternMatcher.h" #include "compiler/translator/IntermNodePatternMatcher.h"
#include "compiler/translator/IntermNode_util.h"
#include "compiler/translator/IntermTraverse.h" #include "compiler/translator/IntermTraverse.h"
namespace sh namespace sh
...@@ -75,23 +76,24 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -75,23 +76,24 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node)
TIntermSequence insertions; TIntermSequence insertions;
TType boolType(EbtBool, EbpUndefined, EvqTemporary); TType boolType(EbtBool, EbpUndefined, EvqTemporary);
nextTemporaryId(); TVariable *resultVariable = CreateTempVariable(mSymbolTable, boolType);
ASSERT(node->getLeft()->getType() == boolType); ASSERT(node->getLeft()->getType() == boolType);
insertions.push_back(createTempInitDeclaration(node->getLeft())); insertions.push_back(CreateTempInitDeclarationNode(resultVariable, node->getLeft()));
TIntermBlock *assignRightBlock = new TIntermBlock(); TIntermBlock *assignRightBlock = new TIntermBlock();
ASSERT(node->getRight()->getType() == boolType); ASSERT(node->getRight()->getType() == boolType);
assignRightBlock->getSequence()->push_back(createTempAssignment(node->getRight())); assignRightBlock->getSequence()->push_back(
CreateTempAssignmentNode(resultVariable, node->getRight()));
TIntermUnary *notTempSymbol = TIntermUnary *notTempSymbol =
new TIntermUnary(EOpLogicalNot, createTempSymbol(boolType)); new TIntermUnary(EOpLogicalNot, CreateTempSymbolNode(resultVariable));
TIntermIfElse *ifNode = new TIntermIfElse(notTempSymbol, assignRightBlock, nullptr); TIntermIfElse *ifNode = new TIntermIfElse(notTempSymbol, assignRightBlock, nullptr);
insertions.push_back(ifNode); insertions.push_back(ifNode);
insertStatementsInParentBlock(insertions); insertStatementsInParentBlock(insertions);
queueReplacement(createTempSymbol(boolType), OriginalNode::IS_DROPPED); queueReplacement(CreateTempSymbolNode(resultVariable), OriginalNode::IS_DROPPED);
return false; return false;
} }
case EOpLogicalAnd: case EOpLogicalAnd:
...@@ -101,22 +103,23 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -101,22 +103,23 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node)
// and then further simplifies down to "bool s = x; if(s) s = y;". // and then further simplifies down to "bool s = x; if(s) s = y;".
TIntermSequence insertions; TIntermSequence insertions;
TType boolType(EbtBool, EbpUndefined, EvqTemporary); TType boolType(EbtBool, EbpUndefined, EvqTemporary);
nextTemporaryId(); TVariable *resultVariable = CreateTempVariable(mSymbolTable, boolType);
ASSERT(node->getLeft()->getType() == boolType); ASSERT(node->getLeft()->getType() == boolType);
insertions.push_back(createTempInitDeclaration(node->getLeft())); insertions.push_back(CreateTempInitDeclarationNode(resultVariable, node->getLeft()));
TIntermBlock *assignRightBlock = new TIntermBlock(); TIntermBlock *assignRightBlock = new TIntermBlock();
ASSERT(node->getRight()->getType() == boolType); ASSERT(node->getRight()->getType() == boolType);
assignRightBlock->getSequence()->push_back(createTempAssignment(node->getRight())); assignRightBlock->getSequence()->push_back(
CreateTempAssignmentNode(resultVariable, node->getRight()));
TIntermIfElse *ifNode = TIntermIfElse *ifNode =
new TIntermIfElse(createTempSymbol(boolType), assignRightBlock, nullptr); new TIntermIfElse(CreateTempSymbolNode(resultVariable), assignRightBlock, nullptr);
insertions.push_back(ifNode); insertions.push_back(ifNode);
insertStatementsInParentBlock(insertions); insertStatementsInParentBlock(insertions);
queueReplacement(createTempSymbol(boolType), OriginalNode::IS_DROPPED); queueReplacement(CreateTempSymbolNode(resultVariable), OriginalNode::IS_DROPPED);
return false; return false;
} }
default: default:
...@@ -140,17 +143,19 @@ bool UnfoldShortCircuitTraverser::visitTernary(Visit visit, TIntermTernary *node ...@@ -140,17 +143,19 @@ bool UnfoldShortCircuitTraverser::visitTernary(Visit visit, TIntermTernary *node
// Unfold "b ? x : y" into "type s; if(b) s = x; else s = y;" // Unfold "b ? x : y" into "type s; if(b) s = x; else s = y;"
TIntermSequence insertions; TIntermSequence insertions;
nextTemporaryId(); TIntermDeclaration *tempDeclaration = nullptr;
TVariable *resultVariable =
TIntermDeclaration *tempDeclaration = createTempDeclaration(node->getType()); DeclareTempVariable(mSymbolTable, node->getType(), EvqTemporary, &tempDeclaration);
insertions.push_back(tempDeclaration); insertions.push_back(tempDeclaration);
TIntermBlock *trueBlock = new TIntermBlock(); TIntermBlock *trueBlock = new TIntermBlock();
TIntermBinary *trueAssignment = createTempAssignment(node->getTrueExpression()); TIntermBinary *trueAssignment =
CreateTempAssignmentNode(resultVariable, node->getTrueExpression());
trueBlock->getSequence()->push_back(trueAssignment); trueBlock->getSequence()->push_back(trueAssignment);
TIntermBlock *falseBlock = new TIntermBlock(); TIntermBlock *falseBlock = new TIntermBlock();
TIntermBinary *falseAssignment = createTempAssignment(node->getFalseExpression()); TIntermBinary *falseAssignment =
CreateTempAssignmentNode(resultVariable, node->getFalseExpression());
falseBlock->getSequence()->push_back(falseAssignment); falseBlock->getSequence()->push_back(falseAssignment);
TIntermIfElse *ifNode = TIntermIfElse *ifNode =
...@@ -159,7 +164,7 @@ bool UnfoldShortCircuitTraverser::visitTernary(Visit visit, TIntermTernary *node ...@@ -159,7 +164,7 @@ bool UnfoldShortCircuitTraverser::visitTernary(Visit visit, TIntermTernary *node
insertStatementsInParentBlock(insertions); insertStatementsInParentBlock(insertions);
TIntermSymbol *ternaryResult = createTempSymbol(node->getType()); TIntermSymbol *ternaryResult = CreateTempSymbolNode(resultVariable);
queueReplacement(ternaryResult, OriginalNode::IS_DROPPED); queueReplacement(ternaryResult, OriginalNode::IS_DROPPED);
return false; return false;
......
...@@ -384,9 +384,6 @@ void UniformHLSL::uniformsHeader(TInfoSinkBase &out, ...@@ -384,9 +384,6 @@ void UniformHLSL::uniformsHeader(TInfoSinkBase &out,
for (TIntermSymbol *sampler : samplerSymbols) for (TIntermSymbol *sampler : samplerSymbols)
{ {
const TType &samplerType = sampler->getType(); const TType &samplerType = sampler->getType();
// Will use angle_ prefix instead of regular prefix.
sampler->setInternal(true);
const TName &samplerName = sampler->getName(); const TName &samplerName = sampler->getName();
if (outputType == SH_HLSL_4_1_OUTPUT) if (outputType == SH_HLSL_4_1_OUTPUT)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include <set> #include <set>
#include "compiler/translator/IntermNode.h" #include "compiler/translator/IntermNode.h"
#include "compiler/translator/IntermNode_util.h"
#include "compiler/translator/IntermTraverse.h" #include "compiler/translator/IntermTraverse.h"
namespace sh namespace sh
...@@ -173,12 +174,13 @@ void VectorizeVectorScalarArithmeticTraverser::replaceAssignInsideConstructor( ...@@ -173,12 +174,13 @@ void VectorizeVectorScalarArithmeticTraverser::replaceAssignInsideConstructor(
TType vecType = node->getType(); TType vecType = node->getType();
vecType.setQualifier(EvqTemporary); vecType.setQualifier(EvqTemporary);
nextTemporaryId();
// gvec s0 = gvec(a); // gvec s0 = gvec(a);
// s0 is called "tempAssignmentTarget" below. // s0 is called "tempAssignmentTarget" below.
TIntermTyped *tempAssignmentTargetInitializer = Vectorize(left->deepCopy(), vecType, nullptr); TIntermTyped *tempAssignmentTargetInitializer = Vectorize(left->deepCopy(), vecType, nullptr);
TIntermDeclaration *tempAssignmentTargetDeclaration = TIntermDeclaration *tempAssignmentTargetDeclaration = nullptr;
createTempInitDeclaration(tempAssignmentTargetInitializer); TVariable *tempAssignmentTarget =
DeclareTempVariable(mSymbolTable, tempAssignmentTargetInitializer, EvqTemporary,
&tempAssignmentTargetDeclaration);
// s0 *= b // s0 *= b
TOperator compoundAssignmentOp = argBinary->getOp(); TOperator compoundAssignmentOp = argBinary->getOp();
...@@ -186,14 +188,14 @@ void VectorizeVectorScalarArithmeticTraverser::replaceAssignInsideConstructor( ...@@ -186,14 +188,14 @@ void VectorizeVectorScalarArithmeticTraverser::replaceAssignInsideConstructor(
{ {
compoundAssignmentOp = EOpVectorTimesScalarAssign; compoundAssignmentOp = EOpVectorTimesScalarAssign;
} }
TIntermBinary *replacementCompoundAssignment = TIntermBinary *replacementCompoundAssignment = new TIntermBinary(
new TIntermBinary(compoundAssignmentOp, createTempSymbol(vecType), right->deepCopy()); compoundAssignmentOp, CreateTempSymbolNode(tempAssignmentTarget), right->deepCopy());
// s0.x // s0.x
TVector<int> swizzleXOffset; TVector<int> swizzleXOffset;
swizzleXOffset.push_back(0); swizzleXOffset.push_back(0);
TIntermSwizzle *tempAssignmentTargetX = TIntermSwizzle *tempAssignmentTargetX =
new TIntermSwizzle(createTempSymbol(vecType), swizzleXOffset); new TIntermSwizzle(CreateTempSymbolNode(tempAssignmentTarget), swizzleXOffset);
// a = s0.x // a = s0.x
TIntermBinary *replacementAssignBackToTarget = TIntermBinary *replacementAssignBackToTarget =
new TIntermBinary(EOpAssign, left->deepCopy(), tempAssignmentTargetX); new TIntermBinary(EOpAssign, left->deepCopy(), tempAssignmentTargetX);
...@@ -202,8 +204,8 @@ void VectorizeVectorScalarArithmeticTraverser::replaceAssignInsideConstructor( ...@@ -202,8 +204,8 @@ void VectorizeVectorScalarArithmeticTraverser::replaceAssignInsideConstructor(
TIntermBinary *replacementSequenceLeft = TIntermBinary *replacementSequenceLeft =
new TIntermBinary(EOpComma, replacementCompoundAssignment, replacementAssignBackToTarget); new TIntermBinary(EOpComma, replacementCompoundAssignment, replacementAssignBackToTarget);
// (s0 *= b, a = s0.x), s0 // (s0 *= b, a = s0.x), s0
TIntermBinary *replacementSequence = TIntermBinary *replacementSequence = new TIntermBinary(
new TIntermBinary(EOpComma, replacementSequenceLeft, createTempSymbol(vecType)); EOpComma, replacementSequenceLeft, CreateTempSymbolNode(tempAssignmentTarget));
insertStatementInParentBlock(tempAssignmentTargetDeclaration); insertStatementInParentBlock(tempAssignmentTargetDeclaration);
queueReplacement(replacementSequence, OriginalNode::IS_DROPPED); queueReplacement(replacementSequence, OriginalNode::IS_DROPPED);
......
...@@ -66,7 +66,9 @@ TIntermTyped *CreateLValueNode(const TString &lValueName, const TType &type) ...@@ -66,7 +66,9 @@ TIntermTyped *CreateLValueNode(const TString &lValueName, const TType &type)
{ {
// We're using a dummy symbol table here, don't need to assign proper symbol ids to these nodes. // We're using a dummy symbol table here, don't need to assign proper symbol ids to these nodes.
TSymbolTable symbolTable; TSymbolTable symbolTable;
return new TIntermSymbol(symbolTable.nextUniqueId(), lValueName, type); TVariable *variable = new TVariable(&symbolTable, NewPoolTString(lValueName.c_str()), type,
SymbolType::UserDefined);
return new TIntermSymbol(variable);
} }
ExpectedLValues CreateIndexedLValueNodeList(const TString &lValueName, ExpectedLValues CreateIndexedLValueNodeList(const TString &lValueName,
...@@ -78,8 +80,9 @@ ExpectedLValues CreateIndexedLValueNodeList(const TString &lValueName, ...@@ -78,8 +80,9 @@ ExpectedLValues CreateIndexedLValueNodeList(const TString &lValueName,
// We're using a dummy symbol table here, don't need to assign proper symbol ids to these nodes. // We're using a dummy symbol table here, don't need to assign proper symbol ids to these nodes.
TSymbolTable symbolTable; TSymbolTable symbolTable;
TIntermSymbol *arraySymbol = TVariable *variable = new TVariable(&symbolTable, NewPoolTString(lValueName.c_str()),
new TIntermSymbol(symbolTable.nextUniqueId(), lValueName, elementType); elementType, SymbolType::UserDefined);
TIntermSymbol *arraySymbol = new TIntermSymbol(variable);
ExpectedLValues expected(arraySize); ExpectedLValues expected(arraySize);
for (unsigned index = 0u; index < arraySize; ++index) for (unsigned index = 0u; index < arraySize; ++index)
......
...@@ -38,17 +38,18 @@ class IntermNodeTest : public testing::Test ...@@ -38,17 +38,18 @@ class IntermNodeTest : public testing::Test
{ {
TInfoSinkBase symbolNameOut; TInfoSinkBase symbolNameOut;
symbolNameOut << "test" << mUniqueIndex; symbolNameOut << "test" << mUniqueIndex;
TString symbolName = symbolNameOut.c_str(); TString *symbolName = NewPoolTString(symbolNameOut.c_str());
++mUniqueIndex; ++mUniqueIndex;
// We're using a dummy symbol table here, don't need to assign proper symbol ids to these // We're using a dummy symbol table here, don't need to assign proper symbol ids to these
// nodes. // nodes.
TSymbolTable symbolTable; TSymbolTable symbolTable;
TType variableType(type);
TIntermSymbol *node = new TIntermSymbol(symbolTable.nextUniqueId(), symbolName, type); variableType.setQualifier(EvqTemporary);
TVariable *variable =
new TVariable(&symbolTable, symbolName, variableType, SymbolType::AngleInternal);
TIntermSymbol *node = new TIntermSymbol(variable);
node->setLine(createUniqueSourceLoc()); node->setLine(createUniqueSourceLoc());
node->setInternal(true);
node->getTypePointer()->setQualifier(EvqTemporary);
return node; return node;
} }
...@@ -126,9 +127,10 @@ TEST_F(IntermNodeTest, DeepCopySymbolNode) ...@@ -126,9 +127,10 @@ TEST_F(IntermNodeTest, DeepCopySymbolNode)
// We're using a dummy symbol table here, don't need to assign proper symbol ids to these nodes. // We're using a dummy symbol table here, don't need to assign proper symbol ids to these nodes.
TSymbolTable symbolTable; TSymbolTable symbolTable;
TIntermSymbol *original = new TIntermSymbol(symbolTable.nextUniqueId(), TString("name"), type); TVariable *variable =
new TVariable(&symbolTable, NewPoolTString("name"), type, SymbolType::AngleInternal);
TIntermSymbol *original = new TIntermSymbol(variable);
original->setLine(getTestSourceLoc()); original->setLine(getTestSourceLoc());
original->setInternal(true);
TIntermTyped *copy = original->deepCopy(); TIntermTyped *copy = original->deepCopy();
checkSymbolCopy(original, copy); checkSymbolCopy(original, copy);
checkTestSourceLoc(copy->getLine()); checkTestSourceLoc(copy->getLine());
......
...@@ -30,8 +30,7 @@ class RewriteDoWhileCrashTest : public ShaderCompileTreeTest ...@@ -30,8 +30,7 @@ class RewriteDoWhileCrashTest : public ShaderCompileTreeTest
} }
}; };
// Make sure that the RewriteDoWhile step doesn't crash due to creating temp symbols before calling // Make sure that the RewriteDoWhile step doesn't crash. Regression test.
// nextTemporaryId(). Regression test.
TEST_F(RewriteDoWhileCrashTest, RunsSuccessfully) TEST_F(RewriteDoWhileCrashTest, RunsSuccessfully)
{ {
const std::string &shaderString = const std::string &shaderString =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment