Commit ec9232bd by Olli Etuaho Committed by Commit Bot

Store unmangled function names in the AST

This makes the code simpler across the board. There are a few cases where mangled names still need to be generated in AST traversers, but they are outweighed by much leaner output code for all function nodes. BUG=angleproject:1490 TEST=angle_unittests, angle_end2end_tests Change-Id: Id3638e0fca6019bbbe6fc5e1b7763870591da2d8 Reviewed-on: https://chromium-review.googlesource.com/461077 Commit-Queue: Olli Etuaho <oetuaho@nvidia.com> Reviewed-by: 's avatarCorentin Wallez <cwallez@chromium.org>
parent 495bd776
...@@ -128,9 +128,8 @@ class PullGradient : public TIntermTraverser ...@@ -128,9 +128,8 @@ class PullGradient : public TIntermTraverser
} }
else if (node->getOp() == EOpCallBuiltInFunction) else if (node->getOp() == EOpCallBuiltInFunction)
{ {
TString name = TFunction::unmangleName(node->getFunctionSymbolInfo()->getName()); if (mGradientBuiltinFunctions.find(node->getFunctionSymbolInfo()->getName()) !=
mGradientBuiltinFunctions.end())
if (mGradientBuiltinFunctions.find(name) != mGradientBuiltinFunctions.end())
{ {
onGradient(); onGradient();
} }
......
...@@ -760,7 +760,7 @@ bool TCompiler::tagUsedFunctions() ...@@ -760,7 +760,7 @@ bool TCompiler::tagUsedFunctions()
// Search from main, starting from the end of the DAG as it usually is the root. // Search from main, starting from the end of the DAG as it usually is the root.
for (size_t i = mCallDag.size(); i-- > 0;) for (size_t i = mCallDag.size(); i-- > 0;)
{ {
if (mCallDag.getRecordFromIndex(i).name == "main(") if (mCallDag.getRecordFromIndex(i).name == "main")
{ {
internalTagUsedFunction(i); internalTagUsedFunction(i);
return true; return true;
......
...@@ -431,7 +431,7 @@ TIntermAggregate *createInternalFunctionCallNode(const TType &type, ...@@ -431,7 +431,7 @@ TIntermAggregate *createInternalFunctionCallNode(const TType &type,
TString name, TString name,
TIntermSequence *arguments) TIntermSequence *arguments)
{ {
TName nameObj(TFunction::GetMangledNameFromCall(name, *arguments)); TName nameObj(name);
nameObj.setInternal(true); nameObj.setInternal(true);
TIntermAggregate *callNode = TIntermAggregate *callNode =
TIntermAggregate::Create(type, EOpCallInternalRawFunction, arguments); TIntermAggregate::Create(type, EOpCallInternalRawFunction, arguments);
......
...@@ -630,7 +630,7 @@ TIntermConstantUnion::TIntermConstantUnion(const TIntermConstantUnion &node) : T ...@@ -630,7 +630,7 @@ TIntermConstantUnion::TIntermConstantUnion(const TIntermConstantUnion &node) : T
void TFunctionSymbolInfo::setFromFunction(const TFunction &function) void TFunctionSymbolInfo::setFromFunction(const TFunction &function)
{ {
setName(function.getMangledName()); setName(function.getName());
setId(TSymbolUniqueId(function)); setId(TSymbolUniqueId(function));
} }
...@@ -3379,7 +3379,6 @@ void TIntermTraverser::queueReplacementWithParent(TIntermNode *parent, ...@@ -3379,7 +3379,6 @@ void TIntermTraverser::queueReplacementWithParent(TIntermNode *parent,
TName TIntermTraverser::GetInternalFunctionName(const char *name) TName TIntermTraverser::GetInternalFunctionName(const char *name)
{ {
TString nameStr(name); TString nameStr(name);
nameStr = TFunction::mangleName(nameStr);
TName nameObj(nameStr); TName nameObj(nameStr);
nameObj.setInternal(true); nameObj.setInternal(true);
return nameObj; return nameObj;
......
...@@ -550,7 +550,7 @@ class TFunctionSymbolInfo ...@@ -550,7 +550,7 @@ class TFunctionSymbolInfo
const TString &getName() const { return mName.getString(); } const TString &getName() const { return mName.getString(); }
void setName(const TString &name) { mName.setString(name); } void setName(const TString &name) { mName.setString(name); }
bool isMain() const { return mName.getString() == "main("; } bool isMain() const { return mName.getString() == "main"; }
void setId(const TSymbolUniqueId &functionId); void setId(const TSymbolUniqueId &functionId);
const TSymbolUniqueId &getId() const; const TSymbolUniqueId &getId() const;
......
...@@ -66,7 +66,7 @@ void TOutputGLSL::visitSymbol(TIntermSymbol *node) ...@@ -66,7 +66,7 @@ void TOutputGLSL::visitSymbol(TIntermSymbol *node)
} }
} }
TString TOutputGLSL::translateTextureFunction(TString &name) TString TOutputGLSL::translateTextureFunction(const TString &name)
{ {
static const char *simpleRename[] = {"texture2DLodEXT", static const char *simpleRename[] = {"texture2DLodEXT",
"texture2DLod", "texture2DLod",
......
...@@ -28,7 +28,7 @@ class TOutputGLSL : public TOutputGLSLBase ...@@ -28,7 +28,7 @@ class TOutputGLSL : public TOutputGLSLBase
protected: protected:
bool writeVariablePrecision(TPrecision) override; bool writeVariablePrecision(TPrecision) override;
void visitSymbol(TIntermSymbol *node) override; void visitSymbol(TIntermSymbol *node) override;
TString translateTextureFunction(TString &name) override; TString translateTextureFunction(const TString &name) override;
}; };
} // namespace sh } // namespace sh
......
...@@ -926,7 +926,7 @@ bool TOutputGLSLBase::visitFunctionPrototype(Visit visit, TIntermFunctionPrototy ...@@ -926,7 +926,7 @@ bool TOutputGLSLBase::visitFunctionPrototype(Visit visit, TIntermFunctionPrototy
if (type.isArray()) if (type.isArray())
out << arrayBrackets(type); out << arrayBrackets(type);
out << " " << hashFunctionNameIfNeeded(node->getFunctionSymbolInfo()->getNameObj()); out << " " << hashFunctionNameIfNeeded(*node->getFunctionSymbolInfo());
out << "("; out << "(";
writeFunctionParameters(*(node->getSequence())); writeFunctionParameters(*(node->getSequence()));
...@@ -946,7 +946,17 @@ bool TOutputGLSLBase::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -946,7 +946,17 @@ bool TOutputGLSLBase::visitAggregate(Visit visit, TIntermAggregate *node)
case EOpCallBuiltInFunction: case EOpCallBuiltInFunction:
// Function call. // Function call.
if (visit == PreVisit) if (visit == PreVisit)
out << hashFunctionNameIfNeeded(node->getFunctionSymbolInfo()->getNameObj()) << "("; {
if (node->getOp() == EOpCallBuiltInFunction)
{
out << translateTextureFunction(node->getFunctionSymbolInfo()->getName());
}
else
{
out << hashFunctionNameIfNeeded(*node->getFunctionSymbolInfo());
}
out << "(";
}
else if (visit == InVisit) else if (visit == InVisit)
out << ", "; out << ", ";
else else
...@@ -1196,22 +1206,17 @@ TString TOutputGLSLBase::hashVariableName(const TName &name) ...@@ -1196,22 +1206,17 @@ TString TOutputGLSLBase::hashVariableName(const TName &name)
return hashName(name); return hashName(name);
} }
TString TOutputGLSLBase::hashFunctionNameIfNeeded(const TName &mangledName) TString TOutputGLSLBase::hashFunctionNameIfNeeded(const TFunctionSymbolInfo &info)
{ {
TString mangledStr = mangledName.getString(); if (info.isMain() || info.getNameObj().isInternal())
TString name = TFunction::unmangleName(mangledStr);
if (mSymbolTable.findBuiltIn(mangledStr, mShaderVersion) != nullptr || name == "main")
return translateTextureFunction(name);
if (mangledName.isInternal())
{ {
// Internal function names are outputted as-is - they may refer to functions manually added // Internal function names are outputted as-is - they may refer to functions manually added
// to the output shader source that are not included in the AST at all. // to the output shader source that are not included in the AST at all.
return name; return info.getName();
} }
else else
{ {
TName nameObj(name); return hashName(info.getNameObj());
return hashName(nameObj);
} }
} }
......
...@@ -70,10 +70,10 @@ class TOutputGLSLBase : public TIntermTraverser ...@@ -70,10 +70,10 @@ class TOutputGLSLBase : public TIntermTraverser
// Same as hashName(), but without hashing built-in variables. // Same as hashName(), but without hashing built-in variables.
TString hashVariableName(const TName &name); TString hashVariableName(const TName &name);
// Same as hashName(), but without hashing built-in functions and with unmangling. // Same as hashName(), but without hashing internal functions or "main".
TString hashFunctionNameIfNeeded(const TName &mangledName); TString hashFunctionNameIfNeeded(const TFunctionSymbolInfo &info);
// Used to translate function names for differences between ESSL and GLSL // Used to translate function names for differences between ESSL and GLSL
virtual TString translateTextureFunction(TString &name) { return name; } virtual TString translateTextureFunction(const TString &name) { return name; }
private: private:
bool structDeclared(const TStructure *structure) const; bool structDeclared(const TStructure *structure) const;
......
...@@ -1590,7 +1590,7 @@ bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition ...@@ -1590,7 +1590,7 @@ bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition
} }
else else
{ {
out << DecorateFunctionIfNeeded(node->getFunctionSymbolInfo()->getNameObj()) out << DecorateIfNeeded(node->getFunctionSymbolInfo()->getNameObj())
<< DisambiguateFunctionName(parameters) << (mOutputLod0Function ? "Lod0(" : "("); << DisambiguateFunctionName(parameters) << (mOutputLod0Function ? "Lod0(" : "(");
} }
...@@ -1722,7 +1722,7 @@ bool OutputHLSL::visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *n ...@@ -1722,7 +1722,7 @@ bool OutputHLSL::visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *n
TIntermSequence *arguments = node->getSequence(); TIntermSequence *arguments = node->getSequence();
TString name = DecorateFunctionIfNeeded(node->getFunctionSymbolInfo()->getNameObj()); TString name = DecorateIfNeeded(node->getFunctionSymbolInfo()->getNameObj());
out << TypeString(node->getType()) << " " << name << DisambiguateFunctionName(arguments) out << TypeString(node->getType()) << " " << name << DisambiguateFunctionName(arguments)
<< (mOutputLod0Function ? "Lod0(" : "("); << (mOutputLod0Function ? "Lod0(" : "(");
...@@ -1776,7 +1776,7 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -1776,7 +1776,7 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
ASSERT(index != CallDAG::InvalidIndex); ASSERT(index != CallDAG::InvalidIndex);
lod0 &= mASTMetadataList[index].mNeedsLod0; lod0 &= mASTMetadataList[index].mNeedsLod0;
out << DecorateFunctionIfNeeded(node->getFunctionSymbolInfo()->getNameObj()); out << DecorateIfNeeded(node->getFunctionSymbolInfo()->getNameObj());
out << DisambiguateFunctionName(node->getSequence()); out << DisambiguateFunctionName(node->getSequence());
out << (lod0 ? "Lod0(" : "("); out << (lod0 ? "Lod0(" : "(");
} }
...@@ -1784,11 +1784,11 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -1784,11 +1784,11 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
{ {
// This path is used for internal functions that don't have their definitions in the // This path is used for internal functions that don't have their definitions in the
// AST, such as precision emulation functions. // AST, such as precision emulation functions.
out << DecorateFunctionIfNeeded(node->getFunctionSymbolInfo()->getNameObj()) << "("; out << DecorateIfNeeded(node->getFunctionSymbolInfo()->getNameObj()) << "(";
} }
else else
{ {
TString name = TFunction::unmangleName(node->getFunctionSymbolInfo()->getName()); const TString &name = node->getFunctionSymbolInfo()->getName();
TBasicType samplerType = (*arguments)[0]->getAsTyped()->getType().getBasicType(); TBasicType samplerType = (*arguments)[0]->getAsTyped()->getType().getBasicType();
int coords = 0; // textureSize(gsampler2DMS) doesn't have a second argument. int coords = 0; // textureSize(gsampler2DMS) doesn't have a second argument.
if (arguments->size() > 1) if (arguments->size() > 1)
......
...@@ -1443,11 +1443,9 @@ void TParseContext::functionCallLValueErrorCheck(const TFunction *fnCandidate, ...@@ -1443,11 +1443,9 @@ void TParseContext::functionCallLValueErrorCheck(const TFunction *fnCandidate,
TIntermTyped *argument = (*(fnCall->getSequence()))[i]->getAsTyped(); TIntermTyped *argument = (*(fnCall->getSequence()))[i]->getAsTyped();
if (!checkCanBeLValue(argument->getLine(), "assign", argument)) if (!checkCanBeLValue(argument->getLine(), "assign", argument))
{ {
TString unmangledName =
TFunction::unmangleName(fnCall->getFunctionSymbolInfo()->getName());
error(argument->getLine(), error(argument->getLine(),
"Constant value cannot be passed for 'out' or 'inout' parameters.", "Constant value cannot be passed for 'out' or 'inout' parameters.",
unmangledName.c_str()); fnCall->getFunctionSymbolInfo()->getName().c_str());
return; return;
} }
} }
...@@ -4277,16 +4275,13 @@ void TParseContext::checkTextureOffsetConst(TIntermAggregate *functionCall) ...@@ -4277,16 +4275,13 @@ void TParseContext::checkTextureOffsetConst(TIntermAggregate *functionCall)
const TString &name = functionCall->getFunctionSymbolInfo()->getName(); const TString &name = functionCall->getFunctionSymbolInfo()->getName();
TIntermNode *offset = nullptr; TIntermNode *offset = nullptr;
TIntermSequence *arguments = functionCall->getSequence(); TIntermSequence *arguments = functionCall->getSequence();
if (name.compare(0, 16, "texelFetchOffset") == 0 || if (name == "texelFetchOffset" || name == "textureLodOffset" ||
name.compare(0, 16, "textureLodOffset") == 0 || name == "textureProjLodOffset" || name == "textureGradOffset" ||
name.compare(0, 20, "textureProjLodOffset") == 0 || name == "textureProjGradOffset")
name.compare(0, 17, "textureGradOffset") == 0 ||
name.compare(0, 21, "textureProjGradOffset") == 0)
{ {
offset = arguments->back(); offset = arguments->back();
} }
else if (name.compare(0, 13, "textureOffset") == 0 || else if (name == "textureOffset" || name == "textureProjOffset")
name.compare(0, 17, "textureProjOffset") == 0)
{ {
// A bias parameter might follow the offset parameter. // A bias parameter might follow the offset parameter.
ASSERT(arguments->size() >= 3); ASSERT(arguments->size() >= 3);
...@@ -4297,9 +4292,8 @@ void TParseContext::checkTextureOffsetConst(TIntermAggregate *functionCall) ...@@ -4297,9 +4292,8 @@ void TParseContext::checkTextureOffsetConst(TIntermAggregate *functionCall)
TIntermConstantUnion *offsetConstantUnion = offset->getAsConstantUnion(); TIntermConstantUnion *offsetConstantUnion = offset->getAsConstantUnion();
if (offset->getAsTyped()->getQualifier() != EvqConst || !offsetConstantUnion) if (offset->getAsTyped()->getQualifier() != EvqConst || !offsetConstantUnion)
{ {
TString unmangledName = TFunction::unmangleName(name);
error(functionCall->getLine(), "Texture offset must be a constant expression", error(functionCall->getLine(), "Texture offset must be a constant expression",
unmangledName.c_str()); name.c_str());
} }
else else
{ {
......
...@@ -71,7 +71,7 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -71,7 +71,7 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
return true; return true;
} }
if (node->getFunctionSymbolInfo()->getName().compare(0, 16, "texelFetchOffset") != 0) if (node->getFunctionSymbolInfo()->getName() != "texelFetchOffset")
{ {
return true; return true;
} }
...@@ -80,16 +80,10 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -80,16 +80,10 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
const TIntermSequence *sequence = node->getSequence(); const TIntermSequence *sequence = node->getSequence();
ASSERT(sequence->size() == 4u); ASSERT(sequence->size() == 4u);
// Decide if there is a 2DArray sampler. // Decide if the sampler is a 2DArray sampler. In that case position is ivec3 and offset is
bool is2DArray = node->getFunctionSymbolInfo()->getName().find("s2a1") != TString::npos; // ivec2.
bool is2DArray = sequence->at(1)->getAsTyped()->getNominalSize() == 3 &&
// Create new argument list from node->getName(). sequence->at(3)->getAsTyped()->getNominalSize() == 2;
// e.g. Get "(is2a1;vi3;i1;" from "texelFetchOffset(is2a1;vi3;i1;vi2;"
TString newArgs = node->getFunctionSymbolInfo()->getName().substr(
16, node->getFunctionSymbolInfo()->getName().length() - 20);
TString newName = "texelFetch" + newArgs;
TSymbol *texelFetchSymbol = symbolTable->findBuiltIn(newName, shaderVersion);
ASSERT(texelFetchSymbol && texelFetchSymbol->isFunction());
// Create new node that represents the call of function texelFetch. // Create new node that represents the call of function texelFetch.
// Its argument list will be: texelFetch(sampler, Position+offset, lod). // Its argument list will be: texelFetch(sampler, Position+offset, lod).
...@@ -135,6 +129,11 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -135,6 +129,11 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
ASSERT(texelFetchArguments->size() == 3u); ASSERT(texelFetchArguments->size() == 3u);
// Get the symbol of the texel fetch function to use.
TString mangledName = TFunction::GetMangledNameFromCall("texelFetch", *texelFetchArguments);
TSymbol *texelFetchSymbol = symbolTable->findBuiltIn(mangledName, shaderVersion);
ASSERT(texelFetchSymbol && texelFetchSymbol->isFunction());
TIntermAggregate *texelFetchNode = TIntermAggregate::CreateBuiltInFunctionCall( TIntermAggregate *texelFetchNode = TIntermAggregate::CreateBuiltInFunctionCall(
*static_cast<const TFunction *>(texelFetchSymbol), texelFetchArguments); *static_cast<const TFunction *>(texelFetchSymbol), texelFetchArguments);
texelFetchNode->setLine(node->getLine()); texelFetchNode->setLine(node->getLine());
......
...@@ -175,10 +175,6 @@ class TFunction : public TSymbol ...@@ -175,10 +175,6 @@ class TFunction : public TSymbol
bool isFunction() const override { return true; } bool isFunction() const override { return true; }
static TString mangleName(const TString &name) { return name + '('; } static TString mangleName(const TString &name) { return name + '('; }
static TString unmangleName(const TString &mangledName)
{
return TString(mangledName.c_str(), mangledName.find_first_of('('));
}
void addParameter(const TConstParameter &p) void addParameter(const TConstParameter &p)
{ {
......
...@@ -241,18 +241,6 @@ TString DecorateIfNeeded(const TName &name) ...@@ -241,18 +241,6 @@ TString DecorateIfNeeded(const TName &name)
} }
} }
TString DecorateFunctionIfNeeded(const TName &name)
{
if (name.isInternal())
{
return TFunction::unmangleName(name.getString());
}
else
{
return Decorate(TFunction::unmangleName(name.getString()));
}
}
TString TypeString(const TType &type) TString TypeString(const TType &type)
{ {
const TStructure *structure = type.getStruct(); const TStructure *structure = type.getStruct();
......
...@@ -65,8 +65,6 @@ TString SamplerString(HLSLTextureSamplerGroup type); ...@@ -65,8 +65,6 @@ TString SamplerString(HLSLTextureSamplerGroup type);
// Prepends an underscore to avoid naming clashes // Prepends an underscore to avoid naming clashes
TString Decorate(const TString &string); TString Decorate(const TString &string);
TString DecorateIfNeeded(const TName &name); TString DecorateIfNeeded(const TName &name);
// Decorates and also unmangles the function name
TString DecorateFunctionIfNeeded(const TName &name);
TString DecorateUniform(const TName &name, const TType &type); TString DecorateUniform(const TName &name, const TType &type);
TString DecorateField(const TString &string, const TStructure &structure); TString DecorateField(const TString &string, const TStructure &structure);
TString DecoratePrivate(const TString &privateText); TString DecoratePrivate(const TString &privateText);
......
...@@ -402,8 +402,11 @@ bool ValidateLimitations::validateFunctionCall(TIntermAggregate *node) ...@@ -402,8 +402,11 @@ bool ValidateLimitations::validateFunctionCall(TIntermAggregate *node)
bool valid = true; bool valid = true;
TSymbolTable &symbolTable = GetGlobalParseContext()->symbolTable; TSymbolTable &symbolTable = GetGlobalParseContext()->symbolTable;
TSymbol *symbol = symbolTable.find(node->getFunctionSymbolInfo()->getName(), // TODO(oetuaho@nvidia.com): It would be neater to leverage TIntermLValueTrackingTraverser to
GetGlobalParseContext()->getShaderVersion()); // keep track of out parameters, rather than doing a symbol table lookup here.
TString mangledName = TFunction::GetMangledNameFromCall(
node->getFunctionSymbolInfo()->getName(), *node->getSequence());
TSymbol *symbol = symbolTable.find(mangledName, GetGlobalParseContext()->getShaderVersion());
ASSERT(symbol && symbol->isFunction()); ASSERT(symbol && symbol->isFunction());
TFunction *function = static_cast<TFunction *>(symbol); TFunction *function = static_cast<TFunction *>(symbol);
for (ParamIndex::const_iterator i = pIndex.begin(); i != pIndex.end(); ++i) for (ParamIndex::const_iterator i = pIndex.begin(); i != pIndex.end(); ++i)
......
...@@ -363,7 +363,7 @@ bool ValidateMultiviewTraverser::visitAggregate(Visit visit, TIntermAggregate *n ...@@ -363,7 +363,7 @@ bool ValidateMultiviewTraverser::visitAggregate(Visit visit, TIntermAggregate *n
mValid = false; mValid = false;
} }
else if (node->getOp() == EOpCallBuiltInFunction && else if (node->getOp() == EOpCallBuiltInFunction &&
TFunction::unmangleName(node->getFunctionSymbolInfo()->getName()) == "imageStore") node->getFunctionSymbolInfo()->getName() == "imageStore")
{ {
// TODO(oetuaho@nvidia.com): Record which built-in functions have side effects in // TODO(oetuaho@nvidia.com): Record which built-in functions have side effects in
// the symbol info instead. // the symbol info instead.
......
...@@ -75,7 +75,7 @@ class TypeTrackingTest : public testing::Test ...@@ -75,7 +75,7 @@ class TypeTrackingTest : public testing::Test
std::string mInfoLog; std::string mInfoLog;
}; };
TEST_F(TypeTrackingTest, FunctionPrototypeMangling) TEST_F(TypeTrackingTest, FunctionPrototype)
{ {
const std::string &shaderString = const std::string &shaderString =
"precision mediump float;\n" "precision mediump float;\n"
...@@ -90,7 +90,7 @@ TEST_F(TypeTrackingTest, FunctionPrototypeMangling) ...@@ -90,7 +90,7 @@ TEST_F(TypeTrackingTest, FunctionPrototypeMangling)
"}\n"; "}\n";
compile(shaderString); compile(shaderString);
ASSERT_FALSE(foundErrorInIntermediateTree()); ASSERT_FALSE(foundErrorInIntermediateTree());
ASSERT_TRUE(foundInIntermediateTree("Function Prototype: fun(f1;")); ASSERT_TRUE(foundInIntermediateTree("Function Prototype: fun"));
} }
TEST_F(TypeTrackingTest, BuiltInFunctionResultPrecision) TEST_F(TypeTrackingTest, BuiltInFunctionResultPrecision)
...@@ -233,7 +233,7 @@ TEST_F(TypeTrackingTest, Texture2DResultTypeAndPrecision) ...@@ -233,7 +233,7 @@ TEST_F(TypeTrackingTest, Texture2DResultTypeAndPrecision)
"}\n"; "}\n";
compile(shaderString); compile(shaderString);
ASSERT_FALSE(foundErrorInIntermediateTree()); ASSERT_FALSE(foundErrorInIntermediateTree());
ASSERT_TRUE(foundInIntermediateTree("texture2D(s21;vf2; (lowp 4-component vector of float)")); ASSERT_TRUE(foundInIntermediateTree("texture2D (lowp 4-component vector of float)"));
} }
TEST_F(TypeTrackingTest, TextureCubeResultTypeAndPrecision) TEST_F(TypeTrackingTest, TextureCubeResultTypeAndPrecision)
...@@ -250,7 +250,7 @@ TEST_F(TypeTrackingTest, TextureCubeResultTypeAndPrecision) ...@@ -250,7 +250,7 @@ TEST_F(TypeTrackingTest, TextureCubeResultTypeAndPrecision)
"}\n"; "}\n";
compile(shaderString); compile(shaderString);
ASSERT_FALSE(foundErrorInIntermediateTree()); ASSERT_FALSE(foundErrorInIntermediateTree());
ASSERT_TRUE(foundInIntermediateTree("textureCube(sC1;vf3; (lowp 4-component vector of float)")); ASSERT_TRUE(foundInIntermediateTree("textureCube (lowp 4-component vector of float)"));
} }
TEST_F(TypeTrackingTest, TextureSizeResultTypeAndPrecision) TEST_F(TypeTrackingTest, TextureSizeResultTypeAndPrecision)
...@@ -271,7 +271,7 @@ TEST_F(TypeTrackingTest, TextureSizeResultTypeAndPrecision) ...@@ -271,7 +271,7 @@ TEST_F(TypeTrackingTest, TextureSizeResultTypeAndPrecision)
"}\n"; "}\n";
compile(shaderString); compile(shaderString);
ASSERT_FALSE(foundErrorInIntermediateTree()); ASSERT_FALSE(foundErrorInIntermediateTree());
ASSERT_TRUE(foundInIntermediateTree("textureSize(s21;i1; (highp 2-component vector of int)")); ASSERT_TRUE(foundInIntermediateTree("textureSize (highp 2-component vector of int)"));
} }
TEST_F(TypeTrackingTest, BuiltInConstructorResultTypeAndPrecision) TEST_F(TypeTrackingTest, BuiltInConstructorResultTypeAndPrecision)
......
...@@ -20,14 +20,18 @@ namespace ...@@ -20,14 +20,18 @@ namespace
class FunctionCallFinder : public TIntermTraverser class FunctionCallFinder : public TIntermTraverser
{ {
public: public:
FunctionCallFinder(const TString &functionName) FunctionCallFinder(const TString &functionMangledName)
: TIntermTraverser(true, false, false), mFunctionName(functionName), mNodeFound(nullptr) : TIntermTraverser(true, false, false),
mFunctionMangledName(functionMangledName),
mNodeFound(nullptr)
{ {
} }
bool visitAggregate(Visit visit, TIntermAggregate *node) override bool visitAggregate(Visit visit, TIntermAggregate *node) override
{ {
if (node->isFunctionCall() && node->getFunctionSymbolInfo()->getName() == mFunctionName) if (node->isFunctionCall() &&
TFunction::GetMangledNameFromCall(node->getFunctionSymbolInfo()->getName(),
*node->getSequence()) == mFunctionMangledName)
{ {
mNodeFound = node; mNodeFound = node;
return false; return false;
...@@ -39,7 +43,7 @@ class FunctionCallFinder : public TIntermTraverser ...@@ -39,7 +43,7 @@ class FunctionCallFinder : public TIntermTraverser
const TIntermAggregate *getNode() const { return mNodeFound; } const TIntermAggregate *getNode() const { return mNodeFound; }
private: private:
TString mFunctionName; TString mFunctionMangledName;
TIntermAggregate *mNodeFound; TIntermAggregate *mNodeFound;
}; };
...@@ -208,9 +212,9 @@ bool MatchOutputCodeTest::notFoundInCode(const char *stringToFind) const ...@@ -208,9 +212,9 @@ bool MatchOutputCodeTest::notFoundInCode(const char *stringToFind) const
return true; return true;
} }
const TIntermAggregate *FindFunctionCallNode(TIntermNode *root, const TString &functionName) const TIntermAggregate *FindFunctionCallNode(TIntermNode *root, const TString &functionMangledName)
{ {
FunctionCallFinder finder(functionName); FunctionCallFinder finder(functionMangledName);
root->traverse(&finder); root->traverse(&finder);
return finder.getNode(); return finder.getNode();
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment