Commit 0c37100d by Olli Etuaho Committed by Commit Bot

Always create TFunctions for function call nodes

This simplifies code and ensures that nodes get consistent data. In the future function call nodes could have a pointer to the TFunction instead of converting the same information into a different data structure. BUG=angleproject:2267 TEST=angle_unittests, angle_end2end_tests Change-Id: Ic0c24bb86b44b9bcc4a5da7f6b03701081a3af5c Reviewed-on: https://chromium-review.googlesource.com/824606Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org> Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
parent 26143fdd
......@@ -39,9 +39,10 @@ TIntermAggregate *CreateReplacementCall(TIntermAggregate *originalCall,
replacementArguments->push_back(arg);
}
replacementArguments->push_back(returnValueTarget);
TIntermAggregate *replacementCall = TIntermAggregate::CreateFunctionCall(
TType(EbtVoid), originalCall->getFunctionSymbolInfo()->getId(),
originalCall->getFunctionSymbolInfo()->getNameObj(), replacementArguments);
ASSERT(originalCall->getFunction());
TIntermAggregate *replacementCall =
TIntermAggregate::CreateFunctionCall(*originalCall->getFunction(), replacementArguments);
replacementCall->setType(TType(EbtVoid));
replacementCall->setLine(originalCall->getLine());
return replacementCall;
}
......
......@@ -148,7 +148,7 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
if (visit == PreVisit && node->getOp() == EOpCallFunctionInAST)
{
// Function call, add the callees
auto it = mFunctions.find(node->getFunctionSymbolInfo()->getId().get());
auto it = mFunctions.find(node->getFunction()->uniqueId().get());
ASSERT(it != mFunctions.end());
// We might be traversing the initializer of a global variable. Even though function
......
......@@ -99,19 +99,19 @@ void InsertInitCallToMain(TIntermBlock *root,
TIntermBlock *initGlobalsBlock = new TIntermBlock();
initGlobalsBlock->getSequence()->swap(*deferredInitializers);
TSymbolUniqueId initGlobalsFunctionId(symbolTable);
const char *kInitGlobalsFunctionName = "initGlobals";
TFunction *initGlobalsFunction =
new TFunction(symbolTable, NewPoolTString("initGlobals"), new TType(EbtVoid),
SymbolType::AngleInternal, false);
TIntermFunctionPrototype *initGlobalsFunctionPrototype =
CreateInternalFunctionPrototypeNode(TType(), kInitGlobalsFunctionName, initGlobalsFunctionId);
CreateInternalFunctionPrototypeNode(*initGlobalsFunction);
root->getSequence()->insert(root->getSequence()->begin(), initGlobalsFunctionPrototype);
TIntermFunctionDefinition *initGlobalsFunctionDefinition = CreateInternalFunctionDefinitionNode(
TType(), kInitGlobalsFunctionName, initGlobalsBlock, initGlobalsFunctionId);
TIntermFunctionDefinition *initGlobalsFunctionDefinition =
CreateInternalFunctionDefinitionNode(*initGlobalsFunction, initGlobalsBlock);
root->appendStatement(initGlobalsFunctionDefinition);
TIntermAggregate *initGlobalsCall = CreateInternalFunctionCallNode(
TType(), kInitGlobalsFunctionName, initGlobalsFunctionId, new TIntermSequence());
TIntermAggregate *initGlobalsCall =
TIntermAggregate::CreateFunctionCall(*initGlobalsFunction, new TIntermSequence());
TIntermBlock *mainBody = FindMainBody(root);
mainBody->getSequence()->insert(mainBody->getSequence()->begin(), initGlobalsCall);
......
......@@ -427,49 +427,6 @@ bool canRoundFloat(const TType &type)
(type.getPrecision() == EbpLow || type.getPrecision() == EbpMedium);
}
TIntermAggregate *createInternalFunctionCallNode(const TType &type,
TString name,
TIntermSequence *arguments)
{
TName nameObj(&name);
nameObj.setInternal(true);
TIntermAggregate *callNode =
TIntermAggregate::Create(type, EOpCallInternalRawFunction, arguments);
callNode->getFunctionSymbolInfo()->setNameObj(nameObj);
return callNode;
}
TIntermAggregate *createRoundingFunctionCallNode(TIntermTyped *roundedChild)
{
TString roundFunctionName;
if (roundedChild->getPrecision() == EbpMedium)
roundFunctionName = "angle_frm";
else
roundFunctionName = "angle_frl";
TIntermSequence *arguments = new TIntermSequence();
arguments->push_back(roundedChild);
TIntermAggregate *callNode =
createInternalFunctionCallNode(roundedChild->getType(), roundFunctionName, arguments);
callNode->getFunctionSymbolInfo()->setKnownToNotHaveSideEffects(true);
return callNode;
}
TIntermAggregate *createCompoundAssignmentFunctionCallNode(TIntermTyped *left,
TIntermTyped *right,
const char *opNameStr)
{
std::stringstream strstr;
if (left->getPrecision() == EbpMedium)
strstr << "angle_compound_" << opNameStr << "_frm";
else
strstr << "angle_compound_" << opNameStr << "_frl";
TString functionName = strstr.str().c_str();
TIntermSequence *arguments = new TIntermSequence();
arguments->push_back(left);
arguments->push_back(right);
return createInternalFunctionCallNode(left->getType(), functionName, arguments);
}
bool ParentUsesResult(TIntermNode *parent, TIntermTyped *node)
{
if (!parent)
......@@ -748,4 +705,50 @@ bool EmulatePrecision::SupportedInLanguage(const ShShaderOutput outputLanguage)
}
}
TFunction *EmulatePrecision::getInternalFunction(TString *functionName,
const TType &returnType,
TIntermSequence *arguments,
bool knownToNotHaveSideEffects)
{
TString mangledName = TFunction::GetMangledNameFromCall(*functionName, *arguments);
if (mInternalFunctions.find(mangledName) == mInternalFunctions.end())
{
mInternalFunctions[mangledName] =
new TFunction(mSymbolTable, functionName, new TType(returnType),
SymbolType::AngleInternal, knownToNotHaveSideEffects);
}
return mInternalFunctions[mangledName];
}
TIntermAggregate *EmulatePrecision::createRoundingFunctionCallNode(TIntermTyped *roundedChild)
{
const char *roundFunctionName;
if (roundedChild->getPrecision() == EbpMedium)
roundFunctionName = "angle_frm";
else
roundFunctionName = "angle_frl";
TString *functionName = NewPoolTString(roundFunctionName);
TIntermSequence *arguments = new TIntermSequence();
arguments->push_back(roundedChild);
return TIntermAggregate::CreateRawFunctionCall(
*getInternalFunction(functionName, roundedChild->getType(), arguments, true), arguments);
}
TIntermAggregate *EmulatePrecision::createCompoundAssignmentFunctionCallNode(TIntermTyped *left,
TIntermTyped *right,
const char *opNameStr)
{
std::stringstream strstr;
if (left->getPrecision() == EbpMedium)
strstr << "angle_compound_" << opNameStr << "_frm";
else
strstr << "angle_compound_" << opNameStr << "_frl";
TString *functionName = NewPoolTString(strstr.str().c_str());
TIntermSequence *arguments = new TIntermSequence();
arguments->push_back(left);
arguments->push_back(right);
return TIntermAggregate::CreateRawFunctionCall(
*getInternalFunction(functionName, left->getType(), arguments, false), arguments);
}
} // namespace sh
......@@ -59,12 +59,24 @@ class EmulatePrecision : public TLValueTrackingTraverser
}
};
TFunction *getInternalFunction(TString *functionName,
const TType &returnType,
TIntermSequence *arguments,
bool knownToNotHaveSideEffects);
TIntermAggregate *createRoundingFunctionCallNode(TIntermTyped *roundedChild);
TIntermAggregate *createCompoundAssignmentFunctionCallNode(TIntermTyped *left,
TIntermTyped *right,
const char *opNameStr);
typedef std::set<TypePair, TypePairComparator> EmulationSet;
EmulationSet mEmulateCompoundAdd;
EmulationSet mEmulateCompoundSub;
EmulationSet mEmulateCompoundMul;
EmulationSet mEmulateCompoundDiv;
// Map from mangled name to function.
TMap<TString, TFunction *> mInternalFunctions;
bool mDeclaringVariables;
};
......
......@@ -287,29 +287,20 @@ TIntermSymbol::TIntermSymbol(const TVariable *variable)
TIntermAggregate *TIntermAggregate::CreateFunctionCall(const TFunction &func,
TIntermSequence *arguments)
{
TIntermAggregate *callNode =
new TIntermAggregate(func.getReturnType(), EOpCallFunctionInAST, arguments);
callNode->getFunctionSymbolInfo()->setFromFunction(func);
return callNode;
return new TIntermAggregate(&func, func.getReturnType(), EOpCallFunctionInAST, arguments);
}
TIntermAggregate *TIntermAggregate::CreateFunctionCall(const TType &type,
const TSymbolUniqueId &id,
const TName &name,
TIntermSequence *arguments)
TIntermAggregate *TIntermAggregate::CreateRawFunctionCall(const TFunction &func,
TIntermSequence *arguments)
{
TIntermAggregate *callNode = new TIntermAggregate(type, EOpCallFunctionInAST, arguments);
callNode->getFunctionSymbolInfo()->setId(id);
callNode->getFunctionSymbolInfo()->setNameObj(name);
return callNode;
return new TIntermAggregate(&func, func.getReturnType(), EOpCallInternalRawFunction, arguments);
}
TIntermAggregate *TIntermAggregate::CreateBuiltInFunctionCall(const TFunction &func,
TIntermSequence *arguments)
{
TIntermAggregate *callNode =
new TIntermAggregate(func.getReturnType(), EOpCallBuiltInFunction, arguments);
callNode->getFunctionSymbolInfo()->setFromFunction(func);
new TIntermAggregate(&func, func.getReturnType(), EOpCallBuiltInFunction, arguments);
// Note that name needs to be set before texture function type is determined.
callNode->setBuiltInFunctionPrecision();
return callNode;
......@@ -318,27 +309,37 @@ TIntermAggregate *TIntermAggregate::CreateBuiltInFunctionCall(const TFunction &f
TIntermAggregate *TIntermAggregate::CreateConstructor(const TType &type,
TIntermSequence *arguments)
{
return new TIntermAggregate(type, EOpConstruct, arguments);
return new TIntermAggregate(nullptr, type, EOpConstruct, arguments);
}
TIntermAggregate *TIntermAggregate::Create(const TType &type,
TOperator op,
TIntermSequence *arguments)
{
TIntermAggregate *node = new TIntermAggregate(type, op, arguments);
ASSERT(op != EOpCallFunctionInAST); // Should use CreateFunctionCall
ASSERT(op != EOpCallInternalRawFunction); // Should use CreateRawFunctionCall
ASSERT(op != EOpCallBuiltInFunction); // Should use CreateBuiltInFunctionCall
ASSERT(!node->isConstructor()); // Should use CreateConstructor
return node;
ASSERT(op != EOpConstruct); // Should use CreateConstructor
return new TIntermAggregate(nullptr, type, op, arguments);
}
TIntermAggregate::TIntermAggregate(const TType &type, TOperator op, TIntermSequence *arguments)
: TIntermOperator(op), mUseEmulatedFunction(false), mGotPrecisionFromChildren(false)
TIntermAggregate::TIntermAggregate(const TFunction *func,
const TType &type,
TOperator op,
TIntermSequence *arguments)
: TIntermOperator(op),
mUseEmulatedFunction(false),
mGotPrecisionFromChildren(false),
mFunction(func)
{
if (arguments != nullptr)
{
mArguments.swap(*arguments);
}
if (mFunction)
{
mFunctionInfo.setFromFunction(*mFunction);
}
setTypePrecisionAndQualifier(type);
}
......@@ -474,9 +475,23 @@ TString TIntermAggregate::getSymbolTableMangledName() const
}
}
const char *TIntermAggregate::functionName() const
{
ASSERT(!isConstructor());
switch (mOp)
{
case EOpCallInternalRawFunction:
case EOpCallBuiltInFunction:
case EOpCallFunctionInAST:
return mFunction->name()->c_str();
default:
return GetOperatorString(mOp);
}
}
bool TIntermAggregate::hasSideEffects() const
{
if (isFunctionCall() && mFunctionInfo.isKnownToNotHaveSideEffects())
if (isFunctionCall() && mFunction != nullptr && mFunction->isKnownToNotHaveSideEffects())
{
for (TIntermNode *arg : mArguments)
{
......@@ -580,17 +595,17 @@ TIntermConstantUnion::TIntermConstantUnion(const TIntermConstantUnion &node) : T
void TFunctionSymbolInfo::setFromFunction(const TFunction &function)
{
setName(*function.name());
mName.setString(*function.name());
mName.setInternal(function.symbolType() == SymbolType::AngleInternal);
setId(TSymbolUniqueId(function));
}
TFunctionSymbolInfo::TFunctionSymbolInfo(const TSymbolUniqueId &id)
: mId(new TSymbolUniqueId(id)), mKnownToNotHaveSideEffects(false)
TFunctionSymbolInfo::TFunctionSymbolInfo(const TSymbolUniqueId &id) : mId(new TSymbolUniqueId(id))
{
}
TFunctionSymbolInfo::TFunctionSymbolInfo(const TFunctionSymbolInfo &info)
: mName(info.mName), mId(nullptr), mKnownToNotHaveSideEffects(info.mKnownToNotHaveSideEffects)
: mName(info.mName), mId(nullptr)
{
if (info.mId)
{
......@@ -627,7 +642,8 @@ TIntermAggregate::TIntermAggregate(const TIntermAggregate &node)
: TIntermOperator(node),
mUseEmulatedFunction(node.mUseEmulatedFunction),
mGotPrecisionFromChildren(node.mGotPrecisionFromChildren),
mFunctionInfo(node.mFunctionInfo)
mFunctionInfo(node.mFunctionInfo),
mFunction(node.mFunction)
{
for (TIntermNode *arg : node.mArguments)
{
......@@ -642,8 +658,8 @@ TIntermAggregate *TIntermAggregate::shallowCopy() const
{
TIntermSequence *copySeq = new TIntermSequence();
copySeq->insert(copySeq->begin(), getSequence()->begin(), getSequence()->end());
TIntermAggregate *copyNode = new TIntermAggregate(mType, mOp, copySeq);
*copyNode->getFunctionSymbolInfo() = mFunctionInfo;
TIntermAggregate *copyNode = new TIntermAggregate(mFunction, mType, mOp, copySeq);
copyNode->mFunctionInfo = mFunctionInfo;
copyNode->setLine(mLine);
return copyNode;
}
......
......@@ -526,7 +526,7 @@ class TFunctionSymbolInfo
public:
POOL_ALLOCATOR_NEW_DELETE();
TFunctionSymbolInfo(const TSymbolUniqueId &id);
TFunctionSymbolInfo() : mId(nullptr), mKnownToNotHaveSideEffects(false) {}
TFunctionSymbolInfo() : mId(nullptr) {}
TFunctionSymbolInfo(const TFunctionSymbolInfo &info);
TFunctionSymbolInfo &operator=(const TFunctionSymbolInfo &info);
......@@ -537,15 +537,8 @@ class TFunctionSymbolInfo
const TName &getNameObj() const { return mName; }
const TString &getName() const { return mName.getString(); }
void setName(const TString &name) { mName.setString(name); }
bool isMain() const { return mName.getString() == "main"; }
void setKnownToNotHaveSideEffects(bool knownToNotHaveSideEffects)
{
mKnownToNotHaveSideEffects = knownToNotHaveSideEffects;
}
bool isKnownToNotHaveSideEffects() const { return mKnownToNotHaveSideEffects; }
void setId(const TSymbolUniqueId &functionId);
const TSymbolUniqueId &getId() const;
......@@ -557,7 +550,6 @@ class TFunctionSymbolInfo
private:
TName mName;
TSymbolUniqueId *mId;
bool mKnownToNotHaveSideEffects;
};
typedef TVector<TIntermNode *> TIntermSequence;
......@@ -598,12 +590,8 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase
public:
static TIntermAggregate *CreateFunctionCall(const TFunction &func, TIntermSequence *arguments);
// If using this, ensure that there's a consistent function definition with the same symbol id
// added to the AST.
static TIntermAggregate *CreateFunctionCall(const TType &type,
const TSymbolUniqueId &id,
const TName &name,
TIntermSequence *arguments);
static TIntermAggregate *CreateRawFunctionCall(const TFunction &func,
TIntermSequence *arguments);
static TIntermAggregate *CreateBuiltInFunctionCall(const TFunction &func,
TIntermSequence *arguments);
......@@ -637,9 +625,13 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase
// Returns true if changing parameter precision may affect the return value.
bool gotPrecisionFromChildren() const { return mGotPrecisionFromChildren; }
TFunctionSymbolInfo *getFunctionSymbolInfo() { return &mFunctionInfo; }
const TFunctionSymbolInfo *getFunctionSymbolInfo() const { return &mFunctionInfo; }
const TFunction *getFunction() const { return mFunction; }
// Get the function name to display to the user in an error message.
const char *functionName() const;
protected:
TIntermSequence mArguments;
......@@ -649,10 +641,16 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase
bool mGotPrecisionFromChildren;
// TODO(oetuaho): Get rid of mFunctionInfo and just keep mFunction.
TFunctionSymbolInfo mFunctionInfo;
const TFunction *const mFunction;
private:
TIntermAggregate(const TType &type, TOperator op, TIntermSequence *arguments);
TIntermAggregate(const TFunction *func,
const TType &type,
TOperator op,
TIntermSequence *arguments);
TIntermAggregate(const TIntermAggregate &node); // note: not deleted, just private!
......
......@@ -16,14 +16,6 @@ namespace sh
namespace
{
TName GetInternalFunctionName(const char *name)
{
TString nameStr(name);
TName nameObj(&nameStr);
nameObj.setInternal(true);
return nameObj;
}
const TFunction *LookUpBuiltInFunction(const TString &name,
const TIntermSequence *arguments,
const TSymbolTable &symbolTable,
......@@ -41,33 +33,18 @@ const TFunction *LookUpBuiltInFunction(const TString &name,
} // anonymous namespace
TIntermFunctionPrototype *CreateInternalFunctionPrototypeNode(const TType &returnType,
const char *name,
const TSymbolUniqueId &functionId)
TIntermFunctionPrototype *CreateInternalFunctionPrototypeNode(const TFunction &func)
{
TIntermFunctionPrototype *functionNode = new TIntermFunctionPrototype(returnType, functionId);
functionNode->getFunctionSymbolInfo()->setNameObj(GetInternalFunctionName(name));
TIntermFunctionPrototype *functionNode =
new TIntermFunctionPrototype(func.getReturnType(), func.uniqueId());
functionNode->getFunctionSymbolInfo()->setFromFunction(func);
return functionNode;
}
TIntermFunctionDefinition *CreateInternalFunctionDefinitionNode(const TType &returnType,
const char *name,
TIntermBlock *functionBody,
const TSymbolUniqueId &functionId)
TIntermFunctionDefinition *CreateInternalFunctionDefinitionNode(const TFunction &func,
TIntermBlock *functionBody)
{
TIntermFunctionPrototype *prototypeNode =
CreateInternalFunctionPrototypeNode(returnType, name, functionId);
return new TIntermFunctionDefinition(prototypeNode, functionBody);
}
TIntermAggregate *CreateInternalFunctionCallNode(const TType &returnType,
const char *name,
const TSymbolUniqueId &functionId,
TIntermSequence *arguments)
{
TIntermAggregate *functionNode = TIntermAggregate::CreateFunctionCall(
returnType, functionId, GetInternalFunctionName(name), arguments);
return functionNode;
return new TIntermFunctionDefinition(CreateInternalFunctionPrototypeNode(func), functionBody);
}
TIntermTyped *CreateZeroNode(const TType &type)
......
......@@ -17,17 +17,9 @@ namespace sh
class TSymbolTable;
class TVariable;
TIntermFunctionPrototype *CreateInternalFunctionPrototypeNode(const TType &returnType,
const char *name,
const TSymbolUniqueId &functionId);
TIntermFunctionDefinition *CreateInternalFunctionDefinitionNode(const TType &returnType,
const char *name,
TIntermBlock *functionBody,
const TSymbolUniqueId &functionId);
TIntermAggregate *CreateInternalFunctionCallNode(const TType &returnType,
const char *name,
const TSymbolUniqueId &functionId,
TIntermSequence *arguments);
TIntermFunctionPrototype *CreateInternalFunctionPrototypeNode(const TFunction &func);
TIntermFunctionDefinition *CreateInternalFunctionDefinitionNode(const TFunction &func,
TIntermBlock *functionBody);
TIntermTyped *CreateZeroNode(const TType &type);
TIntermConstantUnion *CreateIndexNode(int index);
......
......@@ -13,7 +13,7 @@ namespace sh
namespace
{
void OutputFunction(TInfoSinkBase &out, const char *str, TFunctionSymbolInfo *info)
void OutputFunction(TInfoSinkBase &out, const char *str, const TFunctionSymbolInfo *info)
{
const char *internal = info->getNameObj().isInternal() ? " (internal function)" : "";
out << str << internal << ": " << info->getNameObj().getString() << " (symbol id "
......
......@@ -1682,7 +1682,7 @@ void TParseContext::functionCallRValueLValueErrorCheck(const TFunction *fnCandid
{
error(argument->getLine(),
"Writeonly value cannot be passed for 'in' or 'inout' parameters.",
fnCall->getFunctionSymbolInfo()->getName().c_str());
fnCall->functionName());
return;
}
}
......@@ -1692,7 +1692,7 @@ void TParseContext::functionCallRValueLValueErrorCheck(const TFunction *fnCandid
{
error(argument->getLine(),
"Constant value cannot be passed for 'out' or 'inout' parameters.",
fnCall->getFunctionSymbolInfo()->getName().c_str());
fnCall->functionName());
return;
}
}
......@@ -3442,7 +3442,7 @@ TFunction *TParseContext::parseFunctionHeader(const TPublicType &type,
}
// Add the function as a prototype after parsing it (we do not support recursion)
return new TFunction(&symbolTable, name, new TType(type), SymbolType::UserDefined);
return new TFunction(&symbolTable, name, new TType(type), SymbolType::UserDefined, false);
}
TFunction *TParseContext::addNonConstructorFunc(const TString *name, const TSourceLoc &loc)
......@@ -3454,7 +3454,7 @@ TFunction *TParseContext::addNonConstructorFunc(const TString *name, const TSour
// would be enough, but TFunction carries a lot of extra information in addition to that.
// Besides function calls we do have to store constructor calls in the same data structure, for
// them we need to store a TType.
return new TFunction(&symbolTable, name, returnType, SymbolType::NotResolved);
return new TFunction(&symbolTable, name, returnType, SymbolType::NotResolved, false);
}
TFunction *TParseContext::addConstructorFunc(const TPublicType &publicType)
......@@ -3478,7 +3478,7 @@ TFunction *TParseContext::addConstructorFunc(const TPublicType &publicType)
type->setBasicType(EbtFloat);
}
return new TFunction(&symbolTable, nullptr, type, SymbolType::NotResolved, EOpConstruct);
return new TFunction(&symbolTable, nullptr, type, SymbolType::NotResolved, true, EOpConstruct);
}
void TParseContext::checkIsNotUnsizedArray(const TSourceLoc &line,
......@@ -5757,7 +5757,7 @@ TIntermTyped *TParseContext::addFunctionCallOrMethod(TFunction *fnCall,
{
if (thisNode != nullptr)
{
return addMethod(fnCall, arguments, thisNode, loc);
return addMethod(fnCall->name(), arguments, thisNode, loc);
}
TOperator op = fnCall->getBuiltInOp();
......@@ -5768,11 +5768,11 @@ TIntermTyped *TParseContext::addFunctionCallOrMethod(TFunction *fnCall,
else
{
ASSERT(op == EOpNull);
return addNonConstructorFunctionCall(fnCall, arguments, loc);
return addNonConstructorFunctionCall(fnCall->name(), arguments, loc);
}
}
TIntermTyped *TParseContext::addMethod(TFunction *fnCall,
TIntermTyped *TParseContext::addMethod(const TString *name,
TIntermSequence *arguments,
TIntermNode *thisNode,
const TSourceLoc &loc)
......@@ -5782,9 +5782,9 @@ TIntermTyped *TParseContext::addMethod(TFunction *fnCall,
// a constructor. But such a TFunction can't reach here, since the lexer goes into FIELDS
// mode after a dot, which makes type identifiers to be parsed as FIELD_SELECTION instead.
// So accessing fnCall->name() below is safe.
if (*fnCall->name() != "length")
if (*name != "length")
{
error(loc, "invalid method", fnCall->name()->c_str());
error(loc, "invalid method", name->c_str());
}
else if (!arguments->empty())
{
......@@ -5809,26 +5809,27 @@ TIntermTyped *TParseContext::addMethod(TFunction *fnCall,
return CreateZeroNode(TType(EbtInt, EbpUndefined, EvqConst));
}
TIntermTyped *TParseContext::addNonConstructorFunctionCall(TFunction *fnCall,
TIntermTyped *TParseContext::addNonConstructorFunctionCall(const TString *name,
TIntermSequence *arguments,
const TSourceLoc &loc)
{
ASSERT(name);
// First find by unmangled name to check whether the function name has been
// hidden by a variable name or struct typename.
// If a function is found, check for one with a matching argument list.
bool builtIn;
const TSymbol *symbol = symbolTable.find(*fnCall->name(), mShaderVersion, &builtIn);
const TSymbol *symbol = symbolTable.find(*name, mShaderVersion, &builtIn);
if (symbol != nullptr && !symbol->isFunction())
{
error(loc, "function name expected", fnCall->name()->c_str());
error(loc, "function name expected", name->c_str());
}
else
{
symbol = symbolTable.find(TFunction::GetMangledNameFromCall(*fnCall->name(), *arguments),
symbol = symbolTable.find(TFunction::GetMangledNameFromCall(*name, *arguments),
mShaderVersion, &builtIn);
if (symbol == nullptr)
{
error(loc, "no matching overloaded function found", fnCall->name()->c_str());
error(loc, "no matching overloaded function found", name->c_str());
}
else
{
......
......@@ -549,14 +549,14 @@ class TParseContext : angle::NonCopyable
const TSourceLoc &loc);
TIntermTyped *createUnaryMath(TOperator op, TIntermTyped *child, const TSourceLoc &loc);
TIntermTyped *addMethod(TFunction *fnCall,
TIntermTyped *addMethod(const TString *name,
TIntermSequence *arguments,
TIntermNode *thisNode,
const TSourceLoc &loc);
TIntermTyped *addConstructor(TIntermSequence *arguments,
TType type,
const TSourceLoc &line);
TIntermTyped *addNonConstructorFunctionCall(TFunction *fnCall,
TIntermTyped *addNonConstructorFunctionCall(const TString *name,
TIntermSequence *arguments,
const TSourceLoc &loc);
......
......@@ -101,17 +101,17 @@ TIntermTyped *EnsureSignedInt(TIntermTyped *node)
return TIntermAggregate::CreateConstructor(TType(EbtInt), arguments);
}
TType GetFieldType(const TType &indexedType)
TType *GetFieldType(const TType &indexedType)
{
if (indexedType.isMatrix())
{
TType fieldType = TType(indexedType.getBasicType(), indexedType.getPrecision());
fieldType.setPrimarySize(static_cast<unsigned char>(indexedType.getRows()));
TType *fieldType = new TType(indexedType.getBasicType(), indexedType.getPrecision());
fieldType->setPrimarySize(static_cast<unsigned char>(indexedType.getRows()));
return fieldType;
}
else
{
return TType(indexedType.getBasicType(), indexedType.getPrecision());
return new TType(indexedType.getBasicType(), indexedType.getPrecision());
}
}
......@@ -160,12 +160,12 @@ TType GetFieldType(const TType &indexedType)
// Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type,
bool write,
const TSymbolUniqueId &functionId,
const TFunction &func,
TSymbolTable *symbolTable)
{
ASSERT(!type.isArray());
TType fieldType = GetFieldType(type);
const TType *fieldType = GetFieldType(type);
int numCases = 0;
if (type.isMatrix())
{
......@@ -176,15 +176,8 @@ TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type,
numCases = type.getNominalSize();
}
TType returnType(EbtVoid);
if (!write)
{
returnType = fieldType;
}
std::string functionName = GetIndexFunctionName(type, write);
TIntermFunctionPrototype *prototypeNode =
CreateInternalFunctionPrototypeNode(returnType, functionName.c_str(), functionId);
TIntermFunctionPrototype *prototypeNode = CreateInternalFunctionPrototypeNode(func);
TType baseType(type);
// Conservatively use highp here, even if the indexed type is not highp. That way the code can't
......@@ -203,7 +196,7 @@ TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type,
TIntermSymbol *valueParam = nullptr;
if (write)
{
valueParam = CreateValueSymbol(fieldType, symbolTable);
valueParam = CreateValueSymbol(*fieldType, symbolTable);
prototypeNode->getSequence()->push_back(valueParam);
}
......@@ -301,8 +294,8 @@ class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
// Maps of types that are indexed to the indexing function ids used for them. Note that these
// can not store multiple variants of the same type with different precisions - only one
// precision gets stored.
std::map<TType, TSymbolUniqueId *> mIndexedVecAndMatrixTypes;
std::map<TType, TSymbolUniqueId *> mWrittenVecAndMatrixTypes;
std::map<TType, TFunction *> mIndexedVecAndMatrixTypes;
std::map<TType, TFunction *> mWrittenVecAndMatrixTypes;
bool mUsedTreeInsertion;
......@@ -347,26 +340,23 @@ void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root)
// Create a call to dyn_index_*() based on an indirect indexing op node
TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
TIntermTyped *index,
const TSymbolUniqueId &functionId)
TFunction *indexingFunction)
{
ASSERT(node->getOp() == EOpIndexIndirect);
TIntermSequence *arguments = new TIntermSequence();
arguments->push_back(node->getLeft());
arguments->push_back(index);
TType fieldType = GetFieldType(node->getLeft()->getType());
std::string functionName = GetIndexFunctionName(node->getLeft()->getType(), false);
TIntermAggregate *indexingCall =
CreateInternalFunctionCallNode(fieldType, functionName.c_str(), functionId, arguments);
TIntermAggregate::CreateFunctionCall(*indexingFunction, arguments);
indexingCall->setLine(node->getLine());
indexingCall->getFunctionSymbolInfo()->setKnownToNotHaveSideEffects(true);
return indexingCall;
}
TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
TVariable *index,
TVariable *writtenValue,
const TSymbolUniqueId &functionId)
TFunction *indexedWriteFunction)
{
ASSERT(node->getOp() == EOpIndexIndirect);
TIntermSequence *arguments = new TIntermSequence();
......@@ -375,9 +365,8 @@ TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
arguments->push_back(CreateTempSymbolNode(index));
arguments->push_back(CreateTempSymbolNode(writtenValue));
std::string functionName = GetIndexFunctionName(node->getLeft()->getType(), true);
TIntermAggregate *indexedWriteCall =
CreateInternalFunctionCallNode(TType(EbtVoid), functionName.c_str(), functionId, arguments);
TIntermAggregate::CreateFunctionCall(*indexedWriteFunction, arguments);
indexedWriteCall->setLine(node->getLine());
return indexedWriteCall;
}
......@@ -424,14 +413,19 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod
#endif
const TType &type = node->getLeft()->getType();
TSymbolUniqueId *indexingFunctionId = new TSymbolUniqueId(mSymbolTable);
TString *indexingFunctionName =
NewPoolTString(GetIndexFunctionName(type, false).c_str());
TFunction *indexingFunction = nullptr;
if (mIndexedVecAndMatrixTypes.find(type) == mIndexedVecAndMatrixTypes.end())
{
mIndexedVecAndMatrixTypes[type] = indexingFunctionId;
indexingFunction =
new TFunction(mSymbolTable, indexingFunctionName, GetFieldType(type),
SymbolType::AngleInternal, true);
mIndexedVecAndMatrixTypes[type] = indexingFunction;
}
else
{
indexingFunctionId = mIndexedVecAndMatrixTypes[type];
indexingFunction = mIndexedVecAndMatrixTypes[type];
}
if (write)
......@@ -466,14 +460,19 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod
// TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value
// only writes it and doesn't need the previous value. http://anglebug.com/1116
TSymbolUniqueId *indexedWriteFunctionId = new TSymbolUniqueId(mSymbolTable);
TFunction *indexedWriteFunction = nullptr;
if (mWrittenVecAndMatrixTypes.find(type) == mWrittenVecAndMatrixTypes.end())
{
mWrittenVecAndMatrixTypes[type] = indexedWriteFunctionId;
TString *functionName = NewPoolTString(
GetIndexFunctionName(node->getLeft()->getType(), true).c_str());
indexedWriteFunction =
new TFunction(mSymbolTable, functionName, new TType(EbtVoid),
SymbolType::AngleInternal, false);
mWrittenVecAndMatrixTypes[type] = indexedWriteFunction;
}
else
{
indexedWriteFunctionId = mWrittenVecAndMatrixTypes[type];
indexedWriteFunction = mWrittenVecAndMatrixTypes[type];
}
TIntermSequence insertionsBefore;
......@@ -489,7 +488,7 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod
// s1 = dyn_index(v_expr, s0);
TIntermAggregate *indexingCall = CreateIndexFunctionCall(
node, CreateTempSymbolNode(indexVariable), *indexingFunctionId);
node, CreateTempSymbolNode(indexVariable), indexingFunction);
TIntermDeclaration *fieldVariableDeclaration = nullptr;
TVariable *fieldVariable = DeclareTempVariable(
mSymbolTable, indexingCall, EvqTemporary, &fieldVariableDeclaration);
......@@ -497,7 +496,7 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod
// dyn_index_write(v_expr, s0, s1);
TIntermAggregate *indexedWriteCall = CreateIndexedWriteFunctionCall(
node, indexVariable, fieldVariable, *indexedWriteFunctionId);
node, indexVariable, fieldVariable, indexedWriteFunction);
insertionsAfter.push_back(indexedWriteCall);
insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
......@@ -514,7 +513,7 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod
// If the index_expr is unsigned, we'll convert it to signed.
ASSERT(!mRemoveIndexSideEffectsInSubtree);
TIntermAggregate *indexingCall = CreateIndexFunctionCall(
node, EnsureSignedInt(node->getRight()), *indexingFunctionId);
node, EnsureSignedInt(node->getRight()), indexingFunction);
queueReplacement(indexingCall, OriginalNode::IS_DROPPED);
}
}
......
......@@ -66,33 +66,35 @@ void WrapMainAndAppend(TIntermBlock *root,
TSymbolTable *symbolTable)
{
// Replace main() with main0() with the same body.
TSymbolUniqueId oldMainId(symbolTable);
std::stringstream oldMainName;
oldMainName << "main" << oldMainId.get();
TIntermFunctionDefinition *oldMain = CreateInternalFunctionDefinitionNode(
TType(EbtVoid), oldMainName.str().c_str(), main->getBody(), oldMainId);
TFunction *oldMain =
new TFunction(symbolTable, nullptr, new TType(EbtVoid), SymbolType::AngleInternal, false);
TIntermFunctionDefinition *oldMainDefinition =
CreateInternalFunctionDefinitionNode(*oldMain, main->getBody());
bool replaced = root->replaceChildNode(main, oldMain);
bool replaced = root->replaceChildNode(main, oldMainDefinition);
ASSERT(replaced);
// void main()
TFunction *newMain = new TFunction(symbolTable, NewPoolTString("main"), new TType(EbtVoid),
SymbolType::UserDefined, false);
TIntermFunctionPrototype *newMainProto = new TIntermFunctionPrototype(
TType(EbtVoid), main->getFunctionPrototype()->getFunctionSymbolInfo()->getId());
newMainProto->getFunctionSymbolInfo()->setName("main");
newMainProto->getFunctionSymbolInfo()->setFromFunction(*newMain);
// {
// main0();
// codeToRun
// }
TIntermBlock *newMainBody = new TIntermBlock();
TIntermAggregate *oldMainCall = CreateInternalFunctionCallNode(
TType(EbtVoid), oldMainName.str().c_str(), oldMainId, new TIntermSequence());
TIntermAggregate *oldMainCall =
TIntermAggregate::CreateFunctionCall(*oldMain, new TIntermSequence());
newMainBody->appendStatement(oldMainCall);
newMainBody->appendStatement(codeToRun);
// Add the new main() to the root node.
TIntermFunctionDefinition *newMain = new TIntermFunctionDefinition(newMainProto, newMainBody);
root->appendStatement(newMain);
TIntermFunctionDefinition *newMainDefinition =
new TIntermFunctionDefinition(newMainProto, newMainBody);
root->appendStatement(newMainDefinition);
}
} // anonymous namespace
......
......@@ -117,6 +117,7 @@ TFunction::TFunction(TSymbolTable *symbolTable,
const TString *name,
const TType *retType,
SymbolType symbolType,
bool knownToNotHaveSideEffects,
TOperator tOp,
TExtension extension)
: TSymbol(symbolTable, name, symbolType, extension),
......@@ -124,8 +125,12 @@ TFunction::TFunction(TSymbolTable *symbolTable,
mangledName(nullptr),
op(tOp),
defined(false),
mHasPrototypeDeclaration(false)
mHasPrototypeDeclaration(false),
mKnownToNotHaveSideEffects(knownToNotHaveSideEffects)
{
// Functions with an empty name are not allowed.
ASSERT(symbolType != SymbolType::Empty);
ASSERT(name != nullptr || symbolType == SymbolType::AngleInternal || tOp != EOpNull);
}
//
......
......@@ -184,6 +184,7 @@ class TFunction : public TSymbol
const TString *name,
const TType *retType,
SymbolType symbolType,
bool knownToNotHaveSideEffects,
TOperator tOp = EOpNull,
TExtension extension = TExtension::UNDEFINED);
......@@ -222,6 +223,8 @@ class TFunction : public TSymbol
size_t getParamCount() const { return parameters.size(); }
const TConstParameter &getParam(size_t i) const { return parameters[i]; }
bool isKnownToNotHaveSideEffects() const { return mKnownToNotHaveSideEffects; }
private:
void clearParameters();
......@@ -231,9 +234,12 @@ class TFunction : public TSymbol
TParamList parameters;
const TType *returnType;
mutable const TString *mangledName;
// TODO(oetuaho): Remove op from TFunction once TFunction is not used for looking up builtins or
// constructors.
TOperator op;
bool defined;
bool mHasPrototypeDeclaration;
bool mKnownToNotHaveSideEffects;
};
} // namespace sh
......
......@@ -415,7 +415,7 @@ void TSymbolTable::insertBuiltIn(ESymbolLevel level,
else
{
TFunction *function =
new TFunction(this, NewPoolTString(name), rvalue, SymbolType::BuiltIn, op, ext);
new TFunction(this, NewPoolTString(name), rvalue, SymbolType::BuiltIn, false, op, ext);
function->addParameter(TConstParameter(ptype1));
......@@ -481,7 +481,8 @@ void TSymbolTable::insertBuiltInFunctionNoParameters(ESymbolLevel level,
const char *name)
{
insertUnmangledBuiltInName(name, level);
insert(level, new TFunction(this, NewPoolTString(name), rvalue, SymbolType::BuiltIn, op));
insert(level,
new TFunction(this, NewPoolTString(name), rvalue, SymbolType::BuiltIn, false, op));
}
void TSymbolTable::insertBuiltInFunctionNoParametersExt(ESymbolLevel level,
......@@ -491,7 +492,8 @@ void TSymbolTable::insertBuiltInFunctionNoParametersExt(ESymbolLevel level,
const char *name)
{
insertUnmangledBuiltInName(name, level);
insert(level, new TFunction(this, NewPoolTString(name), rvalue, SymbolType::BuiltIn, op, ext));
insert(level,
new TFunction(this, NewPoolTString(name), rvalue, SymbolType::BuiltIn, false, op, ext));
}
TPrecision TSymbolTable::getDefaultPrecision(TBasicType type) const
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment