Commit 68981eb5 by Olli Etuaho Committed by Commit Bot

Track parameter qualifiers of functions in call nodes

We now add a reference to TFunction to all TIntermAggregate nodes where it is possible, including built-in ops. We also make sure that internal TFunctions added in traversers have correct parameter qualifiers. This makes TLValueTrackingTraverser much simpler. Instead of storing traversed functions or looking up builtin functions from the symbol table, determining which function parameters are out parameters can now be done simply by looking it up from the function symbol associated with the aggregate node. Symbol instances are no longer deleted when a symbol table level goes out of scope, and TFunction destructor no longer clears the parameters. They're all either statically allocated or pool allocated, so this does not result in leaks. TEST=angle_unittests BUG=angleproject:2267 Change-Id: I57e5570da5b5a69a98a8778da3c2dc82b6284738 Reviewed-on: https://chromium-review.googlesource.com/881324 Commit-Queue: Olli Etuaho <oetuaho@nvidia.com> Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org>
parent 12da5e75
...@@ -128,9 +128,16 @@ bool ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(Visit visit ...@@ -128,9 +128,16 @@ bool ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(Visit visit
changedFunction.returnValueVariable = changedFunction.returnValueVariable =
new TVariable(mSymbolTable, mReturnValueVariableName, returnValueVariableType, new TVariable(mSymbolTable, mReturnValueVariableName, returnValueVariableType,
SymbolType::AngleInternal); SymbolType::AngleInternal);
changedFunction.func = new TFunction(mSymbolTable, &node->getFunction()->name(), TFunction *func = new TFunction(mSymbolTable, &node->getFunction()->name(),
StaticType::GetBasic<EbtVoid>(), StaticType::GetBasic<EbtVoid>(),
node->getFunction()->symbolType(), false); node->getFunction()->symbolType(), false);
for (size_t i = 0; i < node->getFunction()->getParamCount(); ++i)
{
func->addParameter(node->getFunction()->getParam(i));
}
func->addParameter(TConstParameter(
mReturnValueVariableName, static_cast<const TType *>(returnValueVariableType)));
changedFunction.func = func;
mChangedFunctions[functionId.get()] = changedFunction; mChangedFunctions[functionId.get()] = changedFunction;
} }
TIntermFunctionPrototype *replacement = TIntermFunctionPrototype *replacement =
......
...@@ -402,7 +402,7 @@ bool TCompiler::checkAndSimplifyAST(TIntermBlock *root, ...@@ -402,7 +402,7 @@ bool TCompiler::checkAndSimplifyAST(TIntermBlock *root,
} }
if (shouldRunLoopAndIndexingValidation(compileOptions) && if (shouldRunLoopAndIndexingValidation(compileOptions) &&
!ValidateLimitations(root, shaderType, &symbolTable, shaderVersion, &mDiagnostics)) !ValidateLimitations(root, shaderType, &symbolTable, &mDiagnostics))
{ {
return false; return false;
} }
...@@ -533,14 +533,14 @@ bool TCompiler::checkAndSimplifyAST(TIntermBlock *root, ...@@ -533,14 +533,14 @@ bool TCompiler::checkAndSimplifyAST(TIntermBlock *root,
SimplifyLoopConditions(root, SimplifyLoopConditions(root,
IntermNodePatternMatcher::kMultiDeclaration | IntermNodePatternMatcher::kMultiDeclaration |
IntermNodePatternMatcher::kArrayLengthMethod | simplifyScalarized, IntermNodePatternMatcher::kArrayLengthMethod | simplifyScalarized,
&getSymbolTable(), getShaderVersion()); &getSymbolTable());
// Note that separate declarations need to be run before other AST transformations that // Note that separate declarations need to be run before other AST transformations that
// generate new statements from expressions. // generate new statements from expressions.
SeparateDeclarations(root); SeparateDeclarations(root);
SplitSequenceOperator(root, IntermNodePatternMatcher::kArrayLengthMethod | simplifyScalarized, SplitSequenceOperator(root, IntermNodePatternMatcher::kArrayLengthMethod | simplifyScalarized,
&getSymbolTable(), getShaderVersion()); &getSymbolTable());
RemoveArrayLengthMethod(root); RemoveArrayLengthMethod(root);
...@@ -630,7 +630,7 @@ bool TCompiler::checkAndSimplifyAST(TIntermBlock *root, ...@@ -630,7 +630,7 @@ bool TCompiler::checkAndSimplifyAST(TIntermBlock *root,
SimplifyLoopConditions(root, SimplifyLoopConditions(root,
IntermNodePatternMatcher::kArrayDeclaration | IntermNodePatternMatcher::kArrayDeclaration |
IntermNodePatternMatcher::kNamelessStructDeclaration, IntermNodePatternMatcher::kNamelessStructDeclaration,
&getSymbolTable(), getShaderVersion()); &getSymbolTable());
} }
InitializeUninitializedLocals(root, getShaderVersion(), canUseLoopsToInitialize, InitializeUninitializedLocals(root, getShaderVersion(), canUseLoopsToInitialize,
......
...@@ -470,9 +470,11 @@ bool ParentConstructorTakesCareOfRounding(TIntermNode *parent, TIntermTyped *nod ...@@ -470,9 +470,11 @@ bool ParentConstructorTakesCareOfRounding(TIntermNode *parent, TIntermTyped *nod
} // namespace anonymous } // namespace anonymous
EmulatePrecision::EmulatePrecision(TSymbolTable *symbolTable, int shaderVersion) EmulatePrecision::EmulatePrecision(TSymbolTable *symbolTable)
: TLValueTrackingTraverser(true, true, true, symbolTable, shaderVersion), : TLValueTrackingTraverser(true, true, true, symbolTable),
mDeclaringVariables(false) mDeclaringVariables(false),
mParamXName(NewPoolTString("x")),
mParamYName(NewPoolTString("y"))
{ {
} }
...@@ -705,17 +707,23 @@ bool EmulatePrecision::SupportedInLanguage(const ShShaderOutput outputLanguage) ...@@ -705,17 +707,23 @@ bool EmulatePrecision::SupportedInLanguage(const ShShaderOutput outputLanguage)
} }
} }
TFunction *EmulatePrecision::getInternalFunction(TString *functionName, const TFunction *EmulatePrecision::getInternalFunction(TString *functionName,
const TType &returnType, const TType &returnType,
TIntermSequence *arguments, TIntermSequence *arguments,
bool knownToNotHaveSideEffects) const TVector<TConstParameter> &parameters,
bool knownToNotHaveSideEffects)
{ {
TString mangledName = TFunction::GetMangledNameFromCall(*functionName, *arguments); TString mangledName = TFunction::GetMangledNameFromCall(*functionName, *arguments);
if (mInternalFunctions.find(mangledName) == mInternalFunctions.end()) if (mInternalFunctions.find(mangledName) == mInternalFunctions.end())
{ {
mInternalFunctions[mangledName] = TFunction *func = new TFunction(mSymbolTable, functionName, new TType(returnType),
new TFunction(mSymbolTable, functionName, new TType(returnType), SymbolType::AngleInternal, knownToNotHaveSideEffects);
SymbolType::AngleInternal, knownToNotHaveSideEffects); ASSERT(parameters.size() == arguments->size());
for (size_t i = 0; i < parameters.size(); ++i)
{
func->addParameter(parameters[i]);
}
mInternalFunctions[mangledName] = func;
} }
return mInternalFunctions[mangledName]; return mInternalFunctions[mangledName];
} }
...@@ -730,8 +738,16 @@ TIntermAggregate *EmulatePrecision::createRoundingFunctionCallNode(TIntermTyped ...@@ -730,8 +738,16 @@ TIntermAggregate *EmulatePrecision::createRoundingFunctionCallNode(TIntermTyped
TString *functionName = NewPoolTString(roundFunctionName); TString *functionName = NewPoolTString(roundFunctionName);
TIntermSequence *arguments = new TIntermSequence(); TIntermSequence *arguments = new TIntermSequence();
arguments->push_back(roundedChild); arguments->push_back(roundedChild);
TVector<TConstParameter> parameters;
TType *paramType = new TType(roundedChild->getType());
paramType->setPrecision(EbpHigh);
paramType->setQualifier(EvqIn);
parameters.push_back(TConstParameter(mParamXName, static_cast<const TType *>(paramType)));
return TIntermAggregate::CreateRawFunctionCall( return TIntermAggregate::CreateRawFunctionCall(
*getInternalFunction(functionName, roundedChild->getType(), arguments, true), arguments); *getInternalFunction(functionName, roundedChild->getType(), arguments, parameters, true),
arguments);
} }
TIntermAggregate *EmulatePrecision::createCompoundAssignmentFunctionCallNode(TIntermTyped *left, TIntermAggregate *EmulatePrecision::createCompoundAssignmentFunctionCallNode(TIntermTyped *left,
...@@ -747,8 +763,20 @@ TIntermAggregate *EmulatePrecision::createCompoundAssignmentFunctionCallNode(TIn ...@@ -747,8 +763,20 @@ TIntermAggregate *EmulatePrecision::createCompoundAssignmentFunctionCallNode(TIn
TIntermSequence *arguments = new TIntermSequence(); TIntermSequence *arguments = new TIntermSequence();
arguments->push_back(left); arguments->push_back(left);
arguments->push_back(right); arguments->push_back(right);
TVector<TConstParameter> parameters;
TType *leftParamType = new TType(left->getType());
leftParamType->setPrecision(EbpHigh);
leftParamType->setQualifier(EvqOut);
parameters.push_back(TConstParameter(mParamXName, static_cast<const TType *>(leftParamType)));
TType *rightParamType = new TType(right->getType());
rightParamType->setPrecision(EbpHigh);
rightParamType->setQualifier(EvqIn);
parameters.push_back(TConstParameter(mParamYName, static_cast<const TType *>(rightParamType)));
return TIntermAggregate::CreateRawFunctionCall( return TIntermAggregate::CreateRawFunctionCall(
*getInternalFunction(functionName, left->getType(), arguments, false), arguments); *getInternalFunction(functionName, left->getType(), arguments, parameters, false),
arguments);
} }
} // namespace sh } // namespace sh
...@@ -24,7 +24,7 @@ namespace sh ...@@ -24,7 +24,7 @@ namespace sh
class EmulatePrecision : public TLValueTrackingTraverser class EmulatePrecision : public TLValueTrackingTraverser
{ {
public: public:
EmulatePrecision(TSymbolTable *symbolTable, int shaderVersion); EmulatePrecision(TSymbolTable *symbolTable);
void visitSymbol(TIntermSymbol *node) override; void visitSymbol(TIntermSymbol *node) override;
bool visitBinary(Visit visit, TIntermBinary *node) override; bool visitBinary(Visit visit, TIntermBinary *node) override;
...@@ -59,10 +59,11 @@ class EmulatePrecision : public TLValueTrackingTraverser ...@@ -59,10 +59,11 @@ class EmulatePrecision : public TLValueTrackingTraverser
} }
}; };
TFunction *getInternalFunction(TString *functionName, const TFunction *getInternalFunction(TString *functionName,
const TType &returnType, const TType &returnType,
TIntermSequence *arguments, TIntermSequence *arguments,
bool knownToNotHaveSideEffects); const TVector<TConstParameter> &parameters,
bool knownToNotHaveSideEffects);
TIntermAggregate *createRoundingFunctionCallNode(TIntermTyped *roundedChild); TIntermAggregate *createRoundingFunctionCallNode(TIntermTyped *roundedChild);
TIntermAggregate *createCompoundAssignmentFunctionCallNode(TIntermTyped *left, TIntermAggregate *createCompoundAssignmentFunctionCallNode(TIntermTyped *left,
TIntermTyped *right, TIntermTyped *right,
...@@ -75,9 +76,12 @@ class EmulatePrecision : public TLValueTrackingTraverser ...@@ -75,9 +76,12 @@ class EmulatePrecision : public TLValueTrackingTraverser
EmulationSet mEmulateCompoundDiv; EmulationSet mEmulateCompoundDiv;
// Map from mangled name to function. // Map from mangled name to function.
TMap<TString, TFunction *> mInternalFunctions; TMap<TString, const TFunction *> mInternalFunctions;
bool mDeclaringVariables; bool mDeclaringVariables;
const TString *mParamXName;
const TString *mParamYName;
}; };
} // namespace sh } // namespace sh
......
...@@ -364,7 +364,7 @@ TIntermAggregate *TIntermAggregate::CreateConstructor(const TType &type, ...@@ -364,7 +364,7 @@ TIntermAggregate *TIntermAggregate::CreateConstructor(const TType &type,
return new TIntermAggregate(nullptr, type, EOpConstruct, arguments); return new TIntermAggregate(nullptr, type, EOpConstruct, arguments);
} }
TIntermAggregate *TIntermAggregate::Create(const TType &type, TIntermAggregate *TIntermAggregate::Create(const TFunction &func,
TOperator op, TOperator op,
TIntermSequence *arguments) TIntermSequence *arguments)
{ {
...@@ -372,7 +372,7 @@ TIntermAggregate *TIntermAggregate::Create(const TType &type, ...@@ -372,7 +372,7 @@ TIntermAggregate *TIntermAggregate::Create(const TType &type,
ASSERT(op != EOpCallInternalRawFunction); // Should use CreateRawFunctionCall ASSERT(op != EOpCallInternalRawFunction); // Should use CreateRawFunctionCall
ASSERT(op != EOpCallBuiltInFunction); // Should use CreateBuiltInFunctionCall ASSERT(op != EOpCallBuiltInFunction); // Should use CreateBuiltInFunctionCall
ASSERT(op != EOpConstruct); // Should use CreateConstructor ASSERT(op != EOpConstruct); // Should use CreateConstructor
return new TIntermAggregate(nullptr, type, op, arguments); return new TIntermAggregate(&func, func.getReturnType(), op, arguments);
} }
TIntermAggregate::TIntermAggregate(const TFunction *func, TIntermAggregate::TIntermAggregate(const TFunction *func,
......
...@@ -575,7 +575,9 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase ...@@ -575,7 +575,9 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase
TIntermSequence *arguments); TIntermSequence *arguments);
static TIntermAggregate *CreateConstructor(const TType &type, static TIntermAggregate *CreateConstructor(const TType &type,
TIntermSequence *arguments); TIntermSequence *arguments);
static TIntermAggregate *Create(const TType &type, TOperator op, TIntermSequence *arguments); static TIntermAggregate *Create(const TFunction &func,
TOperator op,
TIntermSequence *arguments);
~TIntermAggregate() {} ~TIntermAggregate() {}
// Note: only supported for nodes that can be a part of an expression. // Note: only supported for nodes that can be a part of an expression.
......
...@@ -264,7 +264,7 @@ TIntermTyped *CreateBuiltInFunctionCallNode(const TString &name, ...@@ -264,7 +264,7 @@ TIntermTyped *CreateBuiltInFunctionCallNode(const TString &name,
{ {
return new TIntermUnary(op, arguments->at(0)->getAsTyped()); return new TIntermUnary(op, arguments->at(0)->getAsTyped());
} }
return TIntermAggregate::Create(fn->getReturnType(), op, arguments); return TIntermAggregate::Create(*fn, op, arguments);
} }
return TIntermAggregate::CreateBuiltInFunctionCall(*fn, arguments); return TIntermAggregate::CreateBuiltInFunctionCall(*fn, arguments);
} }
......
...@@ -176,24 +176,6 @@ void TIntermTraverser::insertStatementInParentBlock(TIntermNode *statement) ...@@ -176,24 +176,6 @@ void TIntermTraverser::insertStatementInParentBlock(TIntermNode *statement)
insertStatementsInParentBlock(insertions); insertStatementsInParentBlock(insertions);
} }
void TLValueTrackingTraverser::addToFunctionMap(const TSymbolUniqueId &id,
TIntermSequence *paramSequence)
{
mFunctionMap[id.get()] = paramSequence;
}
bool TLValueTrackingTraverser::isInFunctionMap(const TIntermAggregate *callNode) const
{
ASSERT(callNode->getOp() == EOpCallFunctionInAST);
return (mFunctionMap.find(callNode->getFunction()->uniqueId().get()) != mFunctionMap.end());
}
TIntermSequence *TLValueTrackingTraverser::getFunctionParameters(const TIntermAggregate *callNode)
{
ASSERT(isInFunctionMap(callNode));
return mFunctionMap[callNode->getFunction()->uniqueId().get()];
}
void TLValueTrackingTraverser::setInFunctionCallOutParameter(bool inOutParameter) void TLValueTrackingTraverser::setInFunctionCallOutParameter(bool inOutParameter)
{ {
mInFunctionCallOutParameter = inOutParameter; mInFunctionCallOutParameter = inOutParameter;
...@@ -662,24 +644,14 @@ void TIntermTraverser::queueReplacementWithParent(TIntermNode *parent, ...@@ -662,24 +644,14 @@ void TIntermTraverser::queueReplacementWithParent(TIntermNode *parent,
TLValueTrackingTraverser::TLValueTrackingTraverser(bool preVisit, TLValueTrackingTraverser::TLValueTrackingTraverser(bool preVisit,
bool inVisit, bool inVisit,
bool postVisit, bool postVisit,
TSymbolTable *symbolTable, TSymbolTable *symbolTable)
int shaderVersion)
: TIntermTraverser(preVisit, inVisit, postVisit, symbolTable), : TIntermTraverser(preVisit, inVisit, postVisit, symbolTable),
mOperatorRequiresLValue(false), mOperatorRequiresLValue(false),
mInFunctionCallOutParameter(false), mInFunctionCallOutParameter(false)
mShaderVersion(shaderVersion)
{ {
ASSERT(symbolTable); ASSERT(symbolTable);
} }
void TLValueTrackingTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node)
{
TIntermSequence *sequence = node->getSequence();
addToFunctionMap(node->getFunction()->uniqueId(), sequence);
TIntermTraverser::traverseFunctionPrototype(node);
}
void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node) void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
...@@ -693,81 +665,31 @@ void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node) ...@@ -693,81 +665,31 @@ void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
if (visit) if (visit)
{ {
if (node->getOp() == EOpCallFunctionInAST) size_t paramIndex = 0u;
for (auto *child : *sequence)
{ {
if (isInFunctionMap(node)) if (node->getFunction())
{ {
TIntermSequence *params = getFunctionParameters(node); // Both built-ins and user defined functions should have the function symbol set.
TIntermSequence::iterator paramIter = params->begin(); ASSERT(paramIndex < node->getFunction()->getParamCount());
for (auto *child : *sequence) TQualifier qualifier =
{ node->getFunction()->getParam(paramIndex).type->getQualifier();
ASSERT(paramIter != params->end()); setInFunctionCallOutParameter(qualifier == EvqOut || qualifier == EvqInOut);
TQualifier qualifier = (*paramIter)->getAsTyped()->getQualifier(); ++paramIndex;
setInFunctionCallOutParameter(qualifier == EvqOut || qualifier == EvqInOut);
child->traverse(this);
if (visit && inVisit)
{
if (child != sequence->back())
visit = visitAggregate(InVisit, node);
}
++paramIter;
}
} }
else else
{ {
// The node might not be in the function map in case we're in the middle of ASSERT(node->isConstructor());
// transforming the AST, and have inserted function call nodes without inserting the
// function definitions yet.
setInFunctionCallOutParameter(false);
for (auto *child : *sequence)
{
child->traverse(this);
if (visit && inVisit)
{
if (child != sequence->back())
visit = visitAggregate(InVisit, node);
}
}
} }
setInFunctionCallOutParameter(false); child->traverse(this);
} if (visit && inVisit)
else
{
// Find the built-in function corresponding to this op so that we can determine the
// in/out qualifiers of its parameters.
const TFunction *builtInFunc = nullptr;
if (!node->isFunctionCall() && !node->isConstructor())
{
builtInFunc = static_cast<const TFunction *>(
mSymbolTable->findBuiltIn(node->getSymbolTableMangledName(), mShaderVersion));
}
size_t paramIndex = 0;
for (auto *child : *sequence)
{ {
// This assumes that raw functions called with if (child != sequence->back())
// EOpCallInternalRawFunction don't have out parameters. visit = visitAggregate(InVisit, node);
TQualifier qualifier = EvqIn;
if (builtInFunc != nullptr)
qualifier = builtInFunc->getParam(paramIndex).type->getQualifier();
setInFunctionCallOutParameter(qualifier == EvqOut || qualifier == EvqInOut);
child->traverse(this);
if (visit && inVisit)
{
if (child != sequence->back())
visit = visitAggregate(InVisit, node);
}
++paramIndex;
} }
setInFunctionCallOutParameter(false);
} }
setInFunctionCallOutParameter(false);
} }
if (visit && postVisit) if (visit && postVisit)
......
...@@ -284,13 +284,11 @@ class TLValueTrackingTraverser : public TIntermTraverser ...@@ -284,13 +284,11 @@ class TLValueTrackingTraverser : public TIntermTraverser
TLValueTrackingTraverser(bool preVisit, TLValueTrackingTraverser(bool preVisit,
bool inVisit, bool inVisit,
bool postVisit, bool postVisit,
TSymbolTable *symbolTable, TSymbolTable *symbolTable);
int shaderVersion);
virtual ~TLValueTrackingTraverser() {} virtual ~TLValueTrackingTraverser() {}
void traverseBinary(TIntermBinary *node) final; void traverseBinary(TIntermBinary *node) final;
void traverseUnary(TIntermUnary *node) final; void traverseUnary(TIntermUnary *node) final;
void traverseFunctionPrototype(TIntermFunctionPrototype *node) final;
void traverseAggregate(TIntermAggregate *node) final; void traverseAggregate(TIntermAggregate *node) final;
protected: protected:
...@@ -309,27 +307,12 @@ class TLValueTrackingTraverser : public TIntermTraverser ...@@ -309,27 +307,12 @@ class TLValueTrackingTraverser : public TIntermTraverser
} }
bool operatorRequiresLValue() const { return mOperatorRequiresLValue; } bool operatorRequiresLValue() const { return mOperatorRequiresLValue; }
// Add a function encountered during traversal to the function map.
void addToFunctionMap(const TSymbolUniqueId &id, TIntermSequence *paramSequence);
// Return true if the prototype or definition of the function being called has been encountered
// during traversal.
bool isInFunctionMap(const TIntermAggregate *callNode) const;
// Return the parameters sequence from the function definition or prototype.
TIntermSequence *getFunctionParameters(const TIntermAggregate *callNode);
// Track whether an l-value is required inside a function call. // Track whether an l-value is required inside a function call.
void setInFunctionCallOutParameter(bool inOutParameter); void setInFunctionCallOutParameter(bool inOutParameter);
bool isInFunctionCallOutParameter() const; bool isInFunctionCallOutParameter() const;
bool mOperatorRequiresLValue; bool mOperatorRequiresLValue;
bool mInFunctionCallOutParameter; bool mInFunctionCallOutParameter;
// Map from function symbol id values to their parameter sequences
TMap<int, TIntermSequence *> mFunctionMap;
const int mShaderVersion;
}; };
} // namespace sh } // namespace sh
......
...@@ -5854,7 +5854,7 @@ TIntermTyped *TParseContext::addNonConstructorFunctionCall(const TString &name, ...@@ -5854,7 +5854,7 @@ TIntermTyped *TParseContext::addNonConstructorFunctionCall(const TString &name,
else else
{ {
TIntermAggregate *callNode = TIntermAggregate *callNode =
TIntermAggregate::Create(fnCandidate->getReturnType(), op, arguments); TIntermAggregate::Create(*fnCandidate, op, arguments);
callNode->setLine(loc); callNode->setLine(loc);
// Some built-in functions have out parameters too. // Some built-in functions have out parameters too.
......
...@@ -23,6 +23,8 @@ namespace sh ...@@ -23,6 +23,8 @@ namespace sh
namespace namespace
{ {
const TType *kIndexType = StaticType::Get<EbtInt, EbpHigh, EvqIn, 1, 1>();
std::string GetIndexFunctionName(const TType &type, bool write) std::string GetIndexFunctionName(const TType &type, bool write)
{ {
TInfoSinkBase nameSink; TInfoSinkBase nameSink;
...@@ -59,31 +61,11 @@ std::string GetIndexFunctionName(const TType &type, bool write) ...@@ -59,31 +61,11 @@ std::string GetIndexFunctionName(const TType &type, bool write)
return nameSink.str(); return nameSink.str();
} }
TIntermSymbol *CreateBaseSymbol(const TType *type, TSymbolTable *symbolTable) TIntermSymbol *CreateParameterSymbol(const TConstParameter &parameter, TSymbolTable *symbolTable)
{
TString *baseString = NewPoolTString("base");
TVariable *baseVariable =
new TVariable(symbolTable, baseString, type, SymbolType::AngleInternal);
return new TIntermSymbol(baseVariable);
}
TIntermSymbol *CreateIndexSymbol(TSymbolTable *symbolTable)
{ {
TString *indexString = NewPoolTString("index"); TVariable *variable =
TVariable *indexVariable = new TVariable(symbolTable, parameter.name, parameter.type, SymbolType::AngleInternal);
new TVariable(symbolTable, indexString, StaticType::Get<EbtInt, EbpHigh, EvqIn, 1, 1>(), return new TIntermSymbol(variable);
SymbolType::AngleInternal);
return new TIntermSymbol(indexVariable);
}
TIntermSymbol *CreateValueSymbol(const TType &type, TSymbolTable *symbolTable)
{
TString *valueString = NewPoolTString("value");
TType *valueType = new TType(type);
valueType->setQualifier(EvqIn);
TVariable *valueVariable =
new TVariable(symbolTable, valueString, valueType, SymbolType::AngleInternal);
return new TIntermSymbol(valueVariable);
} }
TIntermConstantUnion *CreateIntConstantNode(int i) TIntermConstantUnion *CreateIntConstantNode(int i)
...@@ -117,6 +99,20 @@ TType *GetFieldType(const TType &indexedType) ...@@ -117,6 +99,20 @@ TType *GetFieldType(const TType &indexedType)
} }
} }
const TType *GetBaseType(const TType &type, bool write)
{
TType *baseType = new TType(type);
// Conservatively use highp here, even if the indexed type is not highp. That way the code can't
// end up using mediump version of an indexing function for a highp value, if both mediump and
// highp values are being indexed in the shader. For HLSL precision doesn't matter, but in
// principle this code could be used with multiple backends.
baseType->setPrecision(EbpHigh);
baseType->setQualifier(EvqInOut);
if (!write)
baseType->setQualifier(EvqIn);
return baseType;
}
// Generate a read or write function for one field in a vector/matrix. // Generate a read or write function for one field in a vector/matrix.
// Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range // Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range
// indices in other places. // indices in other places.
...@@ -167,7 +163,6 @@ TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type, ...@@ -167,7 +163,6 @@ TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type,
{ {
ASSERT(!type.isArray()); ASSERT(!type.isArray());
const TType *fieldType = GetFieldType(type);
int numCases = 0; int numCases = 0;
if (type.isMatrix()) if (type.isMatrix())
{ {
...@@ -181,24 +176,14 @@ TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type, ...@@ -181,24 +176,14 @@ TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type,
std::string functionName = GetIndexFunctionName(type, write); std::string functionName = GetIndexFunctionName(type, write);
TIntermFunctionPrototype *prototypeNode = CreateInternalFunctionPrototypeNode(func); TIntermFunctionPrototype *prototypeNode = CreateInternalFunctionPrototypeNode(func);
TType *baseType = new TType(type); TIntermSymbol *baseParam = CreateParameterSymbol(func.getParam(0), symbolTable);
// Conservatively use highp here, even if the indexed type is not highp. That way the code can't
// end up using mediump version of an indexing function for a highp value, if both mediump and
// highp values are being indexed in the shader. For HLSL precision doesn't matter, but in
// principle this code could be used with multiple backends.
baseType->setPrecision(EbpHigh);
baseType->setQualifier(EvqInOut);
if (!write)
baseType->setQualifier(EvqIn);
TIntermSymbol *baseParam = CreateBaseSymbol(baseType, symbolTable);
prototypeNode->getSequence()->push_back(baseParam); prototypeNode->getSequence()->push_back(baseParam);
TIntermSymbol *indexParam = CreateIndexSymbol(symbolTable); TIntermSymbol *indexParam = CreateParameterSymbol(func.getParam(1), symbolTable);
prototypeNode->getSequence()->push_back(indexParam); prototypeNode->getSequence()->push_back(indexParam);
TIntermSymbol *valueParam = nullptr; TIntermSymbol *valueParam = nullptr;
if (write) if (write)
{ {
valueParam = CreateValueSymbol(*fieldType, symbolTable); valueParam = CreateParameterSymbol(func.getParam(2), symbolTable);
prototypeNode->getSequence()->push_back(valueParam); prototypeNode->getSequence()->push_back(valueParam);
} }
...@@ -280,7 +265,6 @@ class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser ...@@ -280,7 +265,6 @@ class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
{ {
public: public:
RemoveDynamicIndexingTraverser(TSymbolTable *symbolTable, RemoveDynamicIndexingTraverser(TSymbolTable *symbolTable,
int shaderVersion,
PerformanceDiagnostics *perfDiagnostics); PerformanceDiagnostics *perfDiagnostics);
bool visitBinary(Visit visit, TIntermBinary *node) override; bool visitBinary(Visit visit, TIntermBinary *node) override;
...@@ -307,16 +291,22 @@ class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser ...@@ -307,16 +291,22 @@ class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
bool mRemoveIndexSideEffectsInSubtree; bool mRemoveIndexSideEffectsInSubtree;
PerformanceDiagnostics *mPerfDiagnostics; PerformanceDiagnostics *mPerfDiagnostics;
const TString *mBaseName;
const TString *mIndexName;
const TString *mValueName;
}; };
RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser( RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser(
TSymbolTable *symbolTable, TSymbolTable *symbolTable,
int shaderVersion,
PerformanceDiagnostics *perfDiagnostics) PerformanceDiagnostics *perfDiagnostics)
: TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion), : TLValueTrackingTraverser(true, false, false, symbolTable),
mUsedTreeInsertion(false), mUsedTreeInsertion(false),
mRemoveIndexSideEffectsInSubtree(false), mRemoveIndexSideEffectsInSubtree(false),
mPerfDiagnostics(perfDiagnostics) mPerfDiagnostics(perfDiagnostics),
mBaseName(NewPoolTString("base")),
mIndexName(NewPoolTString("index")),
mValueName(NewPoolTString("value"))
{ {
} }
...@@ -422,6 +412,9 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod ...@@ -422,6 +412,9 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod
indexingFunction = indexingFunction =
new TFunction(mSymbolTable, indexingFunctionName, GetFieldType(type), new TFunction(mSymbolTable, indexingFunctionName, GetFieldType(type),
SymbolType::AngleInternal, true); SymbolType::AngleInternal, true);
indexingFunction->addParameter(
TConstParameter(mBaseName, GetBaseType(type, false)));
indexingFunction->addParameter(TConstParameter(mIndexName, kIndexType));
mIndexedVecAndMatrixTypes[type] = indexingFunction; mIndexedVecAndMatrixTypes[type] = indexingFunction;
} }
else else
...@@ -467,8 +460,15 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod ...@@ -467,8 +460,15 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod
TString *functionName = NewPoolTString( TString *functionName = NewPoolTString(
GetIndexFunctionName(node->getLeft()->getType(), true).c_str()); GetIndexFunctionName(node->getLeft()->getType(), true).c_str());
indexedWriteFunction = indexedWriteFunction =
new TFunction(mSymbolTable, functionName, new TType(EbtVoid), new TFunction(mSymbolTable, functionName, StaticType::GetBasic<EbtVoid>(),
SymbolType::AngleInternal, false); SymbolType::AngleInternal, false);
indexedWriteFunction->addParameter(
TConstParameter(mBaseName, GetBaseType(type, true)));
indexedWriteFunction->addParameter(TConstParameter(mIndexName, kIndexType));
TType *valueType = GetFieldType(type);
valueType->setQualifier(EvqIn);
indexedWriteFunction->addParameter(
TConstParameter(mValueName, static_cast<const TType *>(valueType)));
mWrittenVecAndMatrixTypes[type] = indexedWriteFunction; mWrittenVecAndMatrixTypes[type] = indexedWriteFunction;
} }
else else
...@@ -532,10 +532,9 @@ void RemoveDynamicIndexingTraverser::nextIteration() ...@@ -532,10 +532,9 @@ void RemoveDynamicIndexingTraverser::nextIteration()
void RemoveDynamicIndexing(TIntermNode *root, void RemoveDynamicIndexing(TIntermNode *root,
TSymbolTable *symbolTable, TSymbolTable *symbolTable,
int shaderVersion,
PerformanceDiagnostics *perfDiagnostics) PerformanceDiagnostics *perfDiagnostics)
{ {
RemoveDynamicIndexingTraverser traverser(symbolTable, shaderVersion, perfDiagnostics); RemoveDynamicIndexingTraverser traverser(symbolTable, perfDiagnostics);
do do
{ {
traverser.nextIteration(); traverser.nextIteration();
......
...@@ -19,7 +19,6 @@ class PerformanceDiagnostics; ...@@ -19,7 +19,6 @@ class PerformanceDiagnostics;
void RemoveDynamicIndexing(TIntermNode *root, void RemoveDynamicIndexing(TIntermNode *root,
TSymbolTable *symbolTable, TSymbolTable *symbolTable,
int shaderVersion,
PerformanceDiagnostics *perfDiagnostics); PerformanceDiagnostics *perfDiagnostics);
} // namespace sh } // namespace sh
......
...@@ -25,8 +25,7 @@ class SimplifyLoopConditionsTraverser : public TLValueTrackingTraverser ...@@ -25,8 +25,7 @@ class SimplifyLoopConditionsTraverser : public TLValueTrackingTraverser
{ {
public: public:
SimplifyLoopConditionsTraverser(unsigned int conditionsToSimplifyMask, SimplifyLoopConditionsTraverser(unsigned int conditionsToSimplifyMask,
TSymbolTable *symbolTable, TSymbolTable *symbolTable);
int shaderVersion);
void traverseLoop(TIntermLoop *node) override; void traverseLoop(TIntermLoop *node) override;
...@@ -48,9 +47,8 @@ class SimplifyLoopConditionsTraverser : public TLValueTrackingTraverser ...@@ -48,9 +47,8 @@ class SimplifyLoopConditionsTraverser : public TLValueTrackingTraverser
SimplifyLoopConditionsTraverser::SimplifyLoopConditionsTraverser( SimplifyLoopConditionsTraverser::SimplifyLoopConditionsTraverser(
unsigned int conditionsToSimplifyMask, unsigned int conditionsToSimplifyMask,
TSymbolTable *symbolTable, TSymbolTable *symbolTable)
int shaderVersion) : TLValueTrackingTraverser(true, false, false, symbolTable),
: TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion),
mFoundLoopToChange(false), mFoundLoopToChange(false),
mInsideLoopInitConditionOrExpression(false), mInsideLoopInitConditionOrExpression(false),
mConditionsToSimplify(conditionsToSimplifyMask) mConditionsToSimplify(conditionsToSimplifyMask)
...@@ -292,10 +290,9 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node) ...@@ -292,10 +290,9 @@ void SimplifyLoopConditionsTraverser::traverseLoop(TIntermLoop *node)
void SimplifyLoopConditions(TIntermNode *root, void SimplifyLoopConditions(TIntermNode *root,
unsigned int conditionsToSimplifyMask, unsigned int conditionsToSimplifyMask,
TSymbolTable *symbolTable, TSymbolTable *symbolTable)
int shaderVersion)
{ {
SimplifyLoopConditionsTraverser traverser(conditionsToSimplifyMask, symbolTable, shaderVersion); SimplifyLoopConditionsTraverser traverser(conditionsToSimplifyMask, symbolTable);
root->traverse(&traverser); root->traverse(&traverser);
traverser.updateTree(); traverser.updateTree();
} }
......
...@@ -18,8 +18,7 @@ class TSymbolTable; ...@@ -18,8 +18,7 @@ class TSymbolTable;
void SimplifyLoopConditions(TIntermNode *root, void SimplifyLoopConditions(TIntermNode *root,
unsigned int conditionsToSimplify, unsigned int conditionsToSimplify,
TSymbolTable *symbolTable, TSymbolTable *symbolTable);
int shaderVersion);
} // namespace sh } // namespace sh
#endif // COMPILER_TRANSLATOR_SIMPLIFYLOOPCONDITIONS_H_ #endif // COMPILER_TRANSLATOR_SIMPLIFYLOOPCONDITIONS_H_
...@@ -23,9 +23,7 @@ namespace ...@@ -23,9 +23,7 @@ namespace
class SplitSequenceOperatorTraverser : public TLValueTrackingTraverser class SplitSequenceOperatorTraverser : public TLValueTrackingTraverser
{ {
public: public:
SplitSequenceOperatorTraverser(unsigned int patternsToSplitMask, SplitSequenceOperatorTraverser(unsigned int patternsToSplitMask, TSymbolTable *symbolTable);
TSymbolTable *symbolTable,
int shaderVersion);
bool visitUnary(Visit visit, TIntermUnary *node) override; bool visitUnary(Visit visit, TIntermUnary *node) override;
bool visitBinary(Visit visit, TIntermBinary *node) override; bool visitBinary(Visit visit, TIntermBinary *node) override;
...@@ -45,9 +43,8 @@ class SplitSequenceOperatorTraverser : public TLValueTrackingTraverser ...@@ -45,9 +43,8 @@ class SplitSequenceOperatorTraverser : public TLValueTrackingTraverser
}; };
SplitSequenceOperatorTraverser::SplitSequenceOperatorTraverser(unsigned int patternsToSplitMask, SplitSequenceOperatorTraverser::SplitSequenceOperatorTraverser(unsigned int patternsToSplitMask,
TSymbolTable *symbolTable, TSymbolTable *symbolTable)
int shaderVersion) : TLValueTrackingTraverser(true, false, true, symbolTable),
: TLValueTrackingTraverser(true, false, true, symbolTable, shaderVersion),
mFoundExpressionToSplit(false), mFoundExpressionToSplit(false),
mInsideSequenceOperator(0), mInsideSequenceOperator(0),
mPatternToSplitMatcher(patternsToSplitMask) mPatternToSplitMatcher(patternsToSplitMask)
...@@ -151,12 +148,9 @@ bool SplitSequenceOperatorTraverser::visitTernary(Visit visit, TIntermTernary *n ...@@ -151,12 +148,9 @@ bool SplitSequenceOperatorTraverser::visitTernary(Visit visit, TIntermTernary *n
} // namespace } // namespace
void SplitSequenceOperator(TIntermNode *root, void SplitSequenceOperator(TIntermNode *root, int patternsToSplitMask, TSymbolTable *symbolTable)
int patternsToSplitMask,
TSymbolTable *symbolTable,
int shaderVersion)
{ {
SplitSequenceOperatorTraverser traverser(patternsToSplitMask, symbolTable, shaderVersion); SplitSequenceOperatorTraverser traverser(patternsToSplitMask, symbolTable);
// Separate one expression at a time, and reset the traverser between iterations. // Separate one expression at a time, and reset the traverser between iterations.
do do
{ {
......
...@@ -18,10 +18,7 @@ namespace sh ...@@ -18,10 +18,7 @@ namespace sh
class TIntermNode; class TIntermNode;
class TSymbolTable; class TSymbolTable;
void SplitSequenceOperator(TIntermNode *root, void SplitSequenceOperator(TIntermNode *root, int patternsToSplitMask, TSymbolTable *symbolTable);
int patternsToSplitMask,
TSymbolTable *symbolTable,
int shaderVersion);
} // namespace sh } // namespace sh
......
...@@ -135,18 +135,13 @@ TFunction::TFunction(TSymbolTable *symbolTable, ...@@ -135,18 +135,13 @@ TFunction::TFunction(TSymbolTable *symbolTable,
ASSERT(name != nullptr || symbolType == SymbolType::AngleInternal || tOp != EOpNull); ASSERT(name != nullptr || symbolType == SymbolType::AngleInternal || tOp != EOpNull);
} }
//
// Functions have buried pointers to delete.
//
TFunction::~TFunction() TFunction::~TFunction()
{ {
clearParameters(); // Just here to discourage the compiler from inlining it.
} }
void TFunction::clearParameters() void TFunction::clearParameters()
{ {
for (TParamList::iterator i = parameters.begin(); i != parameters.end(); ++i)
delete (*i).type;
parameters.clear(); parameters.clear();
mangledName = nullptr; mangledName = nullptr;
} }
......
...@@ -25,7 +25,6 @@ class TSymbolTable::TSymbolTableLevel ...@@ -25,7 +25,6 @@ class TSymbolTable::TSymbolTableLevel
{ {
public: public:
TSymbolTableLevel() : mGlobalInvariant(false) {} TSymbolTableLevel() : mGlobalInvariant(false) {}
~TSymbolTableLevel();
bool insert(TSymbol *symbol); bool insert(TSymbol *symbol);
...@@ -62,15 +61,6 @@ class TSymbolTable::TSymbolTableLevel ...@@ -62,15 +61,6 @@ class TSymbolTable::TSymbolTableLevel
std::set<const char *, CharArrayComparator> mUnmangledBuiltInNames; std::set<const char *, CharArrayComparator> mUnmangledBuiltInNames;
}; };
//
// Symbol table levels are a map of pointers to symbols that have to be deleted.
//
TSymbolTable::TSymbolTableLevel::~TSymbolTableLevel()
{
for (tLevel::iterator it = level.begin(); it != level.end(); ++it)
delete (*it).second;
}
bool TSymbolTable::TSymbolTableLevel::insert(TSymbol *symbol) bool TSymbolTable::TSymbolTableLevel::insert(TSymbol *symbol)
{ {
// returning true means symbol was added to the table // returning true means symbol was added to the table
......
...@@ -53,7 +53,7 @@ void TranslatorESSL::translate(TIntermBlock *root, ...@@ -53,7 +53,7 @@ void TranslatorESSL::translate(TIntermBlock *root,
if (precisionEmulation) if (precisionEmulation)
{ {
EmulatePrecision emulatePrecision(&getSymbolTable(), shaderVer); EmulatePrecision emulatePrecision(&getSymbolTable());
root->traverse(&emulatePrecision); root->traverse(&emulatePrecision);
emulatePrecision.updateTree(); emulatePrecision.updateTree();
emulatePrecision.writeEmulationHelpers(sink, shaderVer, SH_ESSL_OUTPUT); emulatePrecision.writeEmulationHelpers(sink, shaderVer, SH_ESSL_OUTPUT);
......
...@@ -109,7 +109,7 @@ void TranslatorGLSL::translate(TIntermBlock *root, ...@@ -109,7 +109,7 @@ void TranslatorGLSL::translate(TIntermBlock *root,
if (precisionEmulation) if (precisionEmulation)
{ {
EmulatePrecision emulatePrecision(&getSymbolTable(), getShaderVersion()); EmulatePrecision emulatePrecision(&getSymbolTable());
root->traverse(&emulatePrecision); root->traverse(&emulatePrecision);
emulatePrecision.updateTree(); emulatePrecision.updateTree();
emulatePrecision.writeEmulationHelpers(sink, getShaderVersion(), getOutputType()); emulatePrecision.writeEmulationHelpers(sink, getShaderVersion(), getOutputType());
......
...@@ -51,13 +51,13 @@ void TranslatorHLSL::translate(TIntermBlock *root, ...@@ -51,13 +51,13 @@ void TranslatorHLSL::translate(TIntermBlock *root,
IntermNodePatternMatcher::kExpressionReturningArray | IntermNodePatternMatcher::kExpressionReturningArray |
IntermNodePatternMatcher::kUnfoldedShortCircuitExpression | IntermNodePatternMatcher::kUnfoldedShortCircuitExpression |
IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue, IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue,
&getSymbolTable(), getShaderVersion()); &getSymbolTable());
SplitSequenceOperator(root, SplitSequenceOperator(root,
IntermNodePatternMatcher::kExpressionReturningArray | IntermNodePatternMatcher::kExpressionReturningArray |
IntermNodePatternMatcher::kUnfoldedShortCircuitExpression | IntermNodePatternMatcher::kUnfoldedShortCircuitExpression |
IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue, IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue,
&getSymbolTable(), getShaderVersion()); &getSymbolTable());
// Note that SeparateDeclarations needs to be run before UnfoldShortCircuitToIf. // Note that SeparateDeclarations needs to be run before UnfoldShortCircuitToIf.
UnfoldShortCircuitToIf(root, &getSymbolTable()); UnfoldShortCircuitToIf(root, &getSymbolTable());
...@@ -76,7 +76,7 @@ void TranslatorHLSL::translate(TIntermBlock *root, ...@@ -76,7 +76,7 @@ void TranslatorHLSL::translate(TIntermBlock *root,
if (!shouldRunLoopAndIndexingValidation(compileOptions)) if (!shouldRunLoopAndIndexingValidation(compileOptions))
{ {
// HLSL doesn't support dynamic indexing of vectors and matrices. // HLSL doesn't support dynamic indexing of vectors and matrices.
RemoveDynamicIndexing(root, &getSymbolTable(), getShaderVersion(), perfDiagnostics); RemoveDynamicIndexing(root, &getSymbolTable(), perfDiagnostics);
} }
// Work around D3D9 bug that would manifest in vertex shaders with selection blocks which // Work around D3D9 bug that would manifest in vertex shaders with selection blocks which
...@@ -106,7 +106,7 @@ void TranslatorHLSL::translate(TIntermBlock *root, ...@@ -106,7 +106,7 @@ void TranslatorHLSL::translate(TIntermBlock *root,
if (precisionEmulation) if (precisionEmulation)
{ {
EmulatePrecision emulatePrecision(&getSymbolTable(), getShaderVersion()); EmulatePrecision emulatePrecision(&getSymbolTable());
root->traverse(&emulatePrecision); root->traverse(&emulatePrecision);
emulatePrecision.updateTree(); emulatePrecision.updateTree();
emulatePrecision.writeEmulationHelpers(getInfoSink().obj, getShaderVersion(), emulatePrecision.writeEmulationHelpers(getInfoSink().obj, getShaderVersion(),
......
...@@ -72,7 +72,6 @@ class ValidateLimitationsTraverser : public TLValueTrackingTraverser ...@@ -72,7 +72,6 @@ class ValidateLimitationsTraverser : public TLValueTrackingTraverser
public: public:
ValidateLimitationsTraverser(sh::GLenum shaderType, ValidateLimitationsTraverser(sh::GLenum shaderType,
TSymbolTable *symbolTable, TSymbolTable *symbolTable,
int shaderVersion,
TDiagnostics *diagnostics); TDiagnostics *diagnostics);
void visitSymbol(TIntermSymbol *node) override; void visitSymbol(TIntermSymbol *node) override;
...@@ -104,9 +103,8 @@ class ValidateLimitationsTraverser : public TLValueTrackingTraverser ...@@ -104,9 +103,8 @@ class ValidateLimitationsTraverser : public TLValueTrackingTraverser
ValidateLimitationsTraverser::ValidateLimitationsTraverser(sh::GLenum shaderType, ValidateLimitationsTraverser::ValidateLimitationsTraverser(sh::GLenum shaderType,
TSymbolTable *symbolTable, TSymbolTable *symbolTable,
int shaderVersion,
TDiagnostics *diagnostics) TDiagnostics *diagnostics)
: TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion), : TLValueTrackingTraverser(true, false, false, symbolTable),
mShaderType(shaderType), mShaderType(shaderType),
mDiagnostics(diagnostics) mDiagnostics(diagnostics)
{ {
...@@ -427,10 +425,9 @@ bool ValidateLimitationsTraverser::validateIndexing(TIntermBinary *node) ...@@ -427,10 +425,9 @@ bool ValidateLimitationsTraverser::validateIndexing(TIntermBinary *node)
bool ValidateLimitations(TIntermNode *root, bool ValidateLimitations(TIntermNode *root,
GLenum shaderType, GLenum shaderType,
TSymbolTable *symbolTable, TSymbolTable *symbolTable,
int shaderVersion,
TDiagnostics *diagnostics) TDiagnostics *diagnostics)
{ {
ValidateLimitationsTraverser validate(shaderType, symbolTable, shaderVersion, diagnostics); ValidateLimitationsTraverser validate(shaderType, symbolTable, diagnostics);
root->traverse(&validate); root->traverse(&validate);
return diagnostics->numErrors() == 0; return diagnostics->numErrors() == 0;
} }
......
...@@ -19,7 +19,6 @@ class TDiagnostics; ...@@ -19,7 +19,6 @@ class TDiagnostics;
bool ValidateLimitations(TIntermNode *root, bool ValidateLimitations(TIntermNode *root,
GLenum shaderType, GLenum shaderType,
TSymbolTable *symbolTable, TSymbolTable *symbolTable,
int shaderVersion,
TDiagnostics *diagnostics); TDiagnostics *diagnostics);
} // namespace sh } // namespace sh
......
...@@ -115,3 +115,44 @@ TEST_F(HLSLOutputTest, ArrayOfArraysStatement) ...@@ -115,3 +115,44 @@ TEST_F(HLSLOutputTest, ArrayOfArraysStatement)
})"; })";
compile(shaderString); compile(shaderString);
} }
// Test dynamic indexing of a vector. This makes sure that helper functions added for dynamic
// indexing have correct data that subsequent traversal steps rely on.
TEST_F(HLSLOutputTest, VectorDynamicIndexing)
{
const std::string &shaderString =
R"(#version 300 es
precision mediump float;
out vec4 outColor;
uniform int i;
void main()
{
vec4 foo = vec4(0.0, 0.0, 0.0, 1.0);
foo[i] = foo[i + 1];
outColor = foo;
})";
compile(shaderString);
}
// Test returning an array from a user-defined function. This makes sure that function symbols are
// changed consistently when the user-defined function is changed to have an array out parameter.
TEST_F(HLSLOutputTest, ArrayReturnValue)
{
const std::string &shaderString =
R"(#version 300 es
precision mediump float;
uniform float u;
out vec4 outColor;
float[2] getArray(float f)
{
return float[2](f, f + 1.0);
}
void main()
{
float[2] arr = getArray(u);
outColor = vec4(arr[0], arr[1], 0.0, 1.0);
})";
compile(shaderString);
}
...@@ -60,6 +60,16 @@ class IntermNodeTest : public testing::Test ...@@ -60,6 +60,16 @@ class IntermNodeTest : public testing::Test
return createTestSymbol(type); return createTestSymbol(type);
} }
TFunction *createTestBuiltInFunction(const TType &returnType, const TIntermSequence &args)
{
// We're using a dummy symbol table similarly as for creating symbol nodes.
TString *name = NewPoolTString("testFunc");
TSymbolTable symbolTable;
TFunction *func =
new TFunction(&symbolTable, name, new TType(returnType), SymbolType::BuiltIn, true);
return func;
}
void checkTypeEqualWithQualifiers(const TType &original, const TType &copy) void checkTypeEqualWithQualifiers(const TType &original, const TType &copy)
{ {
ASSERT_EQ(original, copy); ASSERT_EQ(original, copy);
...@@ -200,8 +210,11 @@ TEST_F(IntermNodeTest, DeepCopyAggregateNode) ...@@ -200,8 +210,11 @@ TEST_F(IntermNodeTest, DeepCopyAggregateNode)
originalSeq->push_back(createTestSymbol()); originalSeq->push_back(createTestSymbol());
originalSeq->push_back(createTestSymbol()); originalSeq->push_back(createTestSymbol());
originalSeq->push_back(createTestSymbol()); originalSeq->push_back(createTestSymbol());
TIntermAggregate *original =
TIntermAggregate::Create(originalSeq->at(0)->getAsTyped()->getType(), EOpMix, originalSeq); TFunction *mix =
createTestBuiltInFunction(originalSeq->back()->getAsTyped()->getType(), *originalSeq);
TIntermAggregate *original = TIntermAggregate::Create(*mix, EOpMix, originalSeq);
original->setLine(getTestSourceLoc()); original->setLine(getTestSourceLoc());
TIntermTyped *copyTyped = original->deepCopy(); TIntermTyped *copyTyped = original->deepCopy();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment