Commit d4bd963f by Olli Etuaho Committed by Commit Bot

Don't use TIntermSymbol nodes for function parameters

Parameter nodes are not needed - it's simpler to just create a TVariable object for each parameter when the TFunction is initialized. With this change we also store only one object per each parameter type used in built-in functions, instead of one array of TConstParameter entries for each unique parameter sequence. This simplifies code and reduces binary size and compiler memory use. Compiler perf does not seem to be significantly affected. BUG=angleproject:2267 TEST=angle_unittests Change-Id: I2b82400dd594731074309f92a705e75135a4c82c Reviewed-on: https://chromium-review.googlesource.com/955589 Commit-Queue: Olli Etuaho <oetuaho@nvidia.com> Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org>
parent 44a73fcf
......@@ -23,15 +23,6 @@ namespace
constexpr const ImmutableString kReturnValueVariableName("angle_return");
void CopyAggregateChildren(TIntermAggregateBase *from, TIntermAggregateBase *to)
{
const TIntermSequence *fromSequence = from->getSequence();
for (size_t ii = 0; ii < fromSequence->size(); ++ii)
{
to->getSequence()->push_back(fromSequence->at(ii));
}
}
class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser
{
public:
......@@ -40,7 +31,7 @@ class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser
private:
ArrayReturnValueToOutParameterTraverser(TSymbolTable *symbolTable);
bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override;
void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *node) override;
bool visitBranch(Visit visit, TIntermBranch *node) override;
......@@ -110,10 +101,9 @@ bool ArrayReturnValueToOutParameterTraverser::visitFunctionDefinition(
return true;
}
bool ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(Visit visit,
TIntermFunctionPrototype *node)
void ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
{
if (visit == PreVisit && node->isArray())
if (node->isArray())
{
// Replace the whole prototype node with another node that has the out parameter
// added. Also set the function to return void.
......@@ -133,21 +123,16 @@ bool ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(Visit visit
{
func->addParameter(node->getFunction()->getParam(i));
}
func->addParameter(TConstParameter(
kReturnValueVariableName, static_cast<const TType *>(returnValueVariableType)));
func->addParameter(changedFunction.returnValueVariable);
changedFunction.func = func;
mChangedFunctions[functionId.get()] = changedFunction;
}
TIntermFunctionPrototype *replacement =
new TIntermFunctionPrototype(mChangedFunctions[functionId.get()].func);
CopyAggregateChildren(node, replacement);
replacement->getSequence()->push_back(
new TIntermSymbol(mChangedFunctions[functionId.get()].returnValueVariable));
replacement->setLine(node->getLine());
queueReplacement(replacement, OriginalNode::IS_DROPPED);
}
return false;
}
bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
......
......@@ -116,16 +116,13 @@ class CallDAG::CallDAGCreator : public TIntermTraverser
return false;
}
bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override
void visitFunctionPrototype(TIntermFunctionPrototype *node) override
{
ASSERT(mCurrentFunction == nullptr);
// Function declaration, create an empty record.
auto &record = mFunctions[node->getFunction()->uniqueId().get()];
record.name = node->getFunction()->name();
// No need to traverse the parameters.
return false;
}
// Track functions called from another function.
......
......@@ -618,11 +618,6 @@ bool EmulatePrecision::visitInvariantDeclaration(Visit visit, TIntermInvariantDe
return false;
}
bool EmulatePrecision::visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node)
{
return false;
}
bool EmulatePrecision::visitAggregate(Visit visit, TIntermAggregate *node)
{
if (visit != PreVisit)
......@@ -709,7 +704,7 @@ bool EmulatePrecision::SupportedInLanguage(const ShShaderOutput outputLanguage)
const TFunction *EmulatePrecision::getInternalFunction(const ImmutableString &functionName,
const TType &returnType,
TIntermSequence *arguments,
const TVector<TConstParameter> &parameters,
const TVector<const TVariable *> &parameters,
bool knownToNotHaveSideEffects)
{
ImmutableString mangledName = TFunctionLookup::GetMangledName(functionName.data(), *arguments);
......@@ -735,11 +730,13 @@ TIntermAggregate *EmulatePrecision::createRoundingFunctionCallNode(TIntermTyped
TIntermSequence *arguments = new TIntermSequence();
arguments->push_back(roundedChild);
TVector<TConstParameter> parameters;
TVector<const TVariable *> parameters;
TType *paramType = new TType(roundedChild->getType());
paramType->setPrecision(EbpHigh);
paramType->setQualifier(EvqIn);
parameters.push_back(TConstParameter(kParamXName, static_cast<const TType *>(paramType)));
parameters.push_back(new TVariable(mSymbolTable, kParamXName,
static_cast<const TType *>(paramType),
SymbolType::AngleInternal));
return TIntermAggregate::CreateRawFunctionCall(
*getInternalFunction(*roundFunctionName, roundedChild->getType(), arguments, parameters,
......@@ -761,15 +758,19 @@ TIntermAggregate *EmulatePrecision::createCompoundAssignmentFunctionCallNode(TIn
arguments->push_back(left);
arguments->push_back(right);
TVector<TConstParameter> parameters;
TVector<const TVariable *> parameters;
TType *leftParamType = new TType(left->getType());
leftParamType->setPrecision(EbpHigh);
leftParamType->setQualifier(EvqOut);
parameters.push_back(TConstParameter(kParamXName, static_cast<const TType *>(leftParamType)));
parameters.push_back(new TVariable(mSymbolTable, kParamXName,
static_cast<const TType *>(leftParamType),
SymbolType::AngleInternal));
TType *rightParamType = new TType(right->getType());
rightParamType->setPrecision(EbpHigh);
rightParamType->setQualifier(EvqIn);
parameters.push_back(TConstParameter(kParamYName, static_cast<const TType *>(rightParamType)));
parameters.push_back(new TVariable(mSymbolTable, kParamYName,
static_cast<const TType *>(rightParamType),
SymbolType::AngleInternal));
return TIntermAggregate::CreateRawFunctionCall(
*getInternalFunction(functionName, left->getType(), arguments, parameters, false),
......
......@@ -32,7 +32,6 @@ class EmulatePrecision : public TLValueTrackingTraverser
bool visitAggregate(Visit visit, TIntermAggregate *node) override;
bool visitInvariantDeclaration(Visit visit, TIntermInvariantDeclaration *node) override;
bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override;
void writeEmulationHelpers(TInfoSinkBase &sink,
const int shaderVersion,
......@@ -62,7 +61,7 @@ class EmulatePrecision : public TLValueTrackingTraverser
const TFunction *getInternalFunction(const ImmutableString &functionName,
const TType &returnType,
TIntermSequence *arguments,
const TVector<TConstParameter> &parameters,
const TVector<const TVariable *> &parameters,
bool knownToNotHaveSideEffects);
TIntermAggregate *createRoundingFunctionCallNode(TIntermTyped *roundedChild);
TIntermAggregate *createCompoundAssignmentFunctionCallNode(TIntermTyped *left,
......
......@@ -264,7 +264,7 @@ bool TIntermBlock::replaceChildNode(TIntermNode *original, TIntermNode *replacem
bool TIntermFunctionPrototype::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
return replaceChildNodeInternal(original, replacement);
return false;
}
bool TIntermDeclaration::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
......@@ -689,12 +689,6 @@ void TIntermBlock::appendStatement(TIntermNode *statement)
}
}
void TIntermFunctionPrototype::appendParameter(TIntermSymbol *parameter)
{
ASSERT(parameter != nullptr);
mParameters.push_back(parameter);
}
void TIntermDeclaration::appendDeclarator(TIntermTyped *declarator)
{
ASSERT(declarator != nullptr);
......
......@@ -660,7 +660,7 @@ class TIntermBlock : public TIntermNode, public TIntermAggregateBase
// Function prototype. May be in the AST either as a function prototype declaration or as a part of
// a function definition. The type of the node is the function return type.
class TIntermFunctionPrototype : public TIntermTyped, public TIntermAggregateBase
class TIntermFunctionPrototype : public TIntermTyped
{
public:
TIntermFunctionPrototype(const TFunction *function);
......@@ -683,17 +683,9 @@ class TIntermFunctionPrototype : public TIntermTyped, public TIntermAggregateBas
return true;
}
// Only intended for initially building the declaration.
void appendParameter(TIntermSymbol *parameter);
TIntermSequence *getSequence() override { return &mParameters; }
const TIntermSequence *getSequence() const override { return &mParameters; }
const TFunction *getFunction() const { return mFunction; }
protected:
TIntermSequence mParameters;
const TFunction *const mFunction;
};
......
......@@ -496,29 +496,7 @@ void TIntermTraverser::traverseDeclaration(TIntermDeclaration *node)
void TIntermTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node)
{
ScopedNodeInTraversalPath addToPath(this, node);
bool visit = true;
TIntermSequence *sequence = node->getSequence();
if (preVisit)
visit = visitFunctionPrototype(PreVisit, node);
if (visit)
{
for (auto *child : *sequence)
{
child->traverse(this);
if (visit && inVisit)
{
if (child != sequence->back())
visit = visitFunctionPrototype(InVisit, node);
}
}
}
if (visit && postVisit)
visitFunctionPrototype(PostVisit, node);
visitFunctionPrototype(node);
}
// Traverse an aggregate node. Same comments in binary node apply here.
......@@ -673,7 +651,7 @@ void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
// Both built-ins and user defined functions should have the function symbol set.
ASSERT(paramIndex < node->getFunction()->getParamCount());
TQualifier qualifier =
node->getFunction()->getParam(paramIndex).type->getQualifier();
node->getFunction()->getParam(paramIndex)->getType().getQualifier();
setInFunctionCallOutParameter(qualifier == EvqOut || qualifier == EvqInOut);
++paramIndex;
}
......
......@@ -53,10 +53,7 @@ class TIntermTraverser : angle::NonCopyable
virtual bool visitIfElse(Visit visit, TIntermIfElse *node) { return true; }
virtual bool visitSwitch(Visit visit, TIntermSwitch *node) { return true; }
virtual bool visitCase(Visit visit, TIntermCase *node) { return true; }
virtual bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node)
{
return true;
}
virtual void visitFunctionPrototype(TIntermFunctionPrototype *node) {}
virtual bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{
return true;
......
......@@ -328,24 +328,23 @@ void TOutputGLSLBase::writeVariableType(const TType &type)
}
}
void TOutputGLSLBase::writeFunctionParameters(const TIntermSequence &args)
void TOutputGLSLBase::writeFunctionParameters(const TFunction *func)
{
TInfoSinkBase &out = objSink();
for (TIntermSequence::const_iterator iter = args.begin(); iter != args.end(); ++iter)
size_t paramCount = func->getParamCount();
for (size_t i = 0; i < paramCount; ++i)
{
const TIntermSymbol *arg = (*iter)->getAsSymbolNode();
ASSERT(arg != nullptr);
const TType &type = arg->getType();
const TVariable *param = func->getParam(i);
const TType &type = param->getType();
writeVariableType(type);
if (arg->variable().symbolType() != SymbolType::Empty)
out << " " << hashName(&arg->variable());
if (param->symbolType() != SymbolType::Empty)
out << " " << hashName(param);
if (type.isArray())
out << ArrayString(type);
// Put a comma if this is not the last argument.
if (iter != args.end() - 1)
if (i != paramCount - 1)
out << ", ";
}
}
......@@ -890,10 +889,9 @@ bool TOutputGLSLBase::visitInvariantDeclaration(Visit visit, TIntermInvariantDec
return false;
}
bool TOutputGLSLBase::visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node)
void TOutputGLSLBase::visitFunctionPrototype(TIntermFunctionPrototype *node)
{
TInfoSinkBase &out = objSink();
ASSERT(visit == PreVisit);
const TType &type = node->getType();
writeVariableType(type);
......@@ -903,10 +901,8 @@ bool TOutputGLSLBase::visitFunctionPrototype(Visit visit, TIntermFunctionPrototy
out << " " << hashFunctionNameIfNeeded(node->getFunction());
out << "(";
writeFunctionParameters(*(node->getSequence()));
writeFunctionParameters(node->getFunction());
out << ")";
return false;
}
bool TOutputGLSLBase::visitAggregate(Visit visit, TIntermAggregate *node)
......
......@@ -44,7 +44,7 @@ class TOutputGLSLBase : public TIntermTraverser
void writeInvariantQualifier(const TType &type);
void writeVariableType(const TType &type);
virtual bool writeVariablePrecision(TPrecision precision) = 0;
void writeFunctionParameters(const TIntermSequence &args);
void writeFunctionParameters(const TFunction *func);
const TConstantUnion *writeConstantUnion(const TType &type, const TConstantUnion *pConstUnion);
void writeConstructorTriplet(Visit visit, const TType &type);
ImmutableString getTypeName(const TType &type);
......@@ -58,7 +58,7 @@ class TOutputGLSLBase : public TIntermTraverser
bool visitIfElse(Visit visit, TIntermIfElse *node) override;
bool visitSwitch(Visit visit, TIntermSwitch *node) override;
bool visitCase(Visit visit, TIntermCase *node) override;
bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override;
void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *node) override;
bool visitBlock(Visit visit, TIntermBlock *node) override;
......
......@@ -1756,36 +1756,31 @@ bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition
out << TypeString(node->getFunctionPrototype()->getType()) << " ";
TIntermSequence *parameters = node->getFunctionPrototype()->getSequence();
const TFunction *func = node->getFunction();
if (node->getFunction()->isMain())
if (func->isMain())
{
out << "gl_main(";
}
else
{
out << DecorateFunctionIfNeeded(node->getFunction()) << DisambiguateFunctionName(parameters)
out << DecorateFunctionIfNeeded(func) << DisambiguateFunctionName(func)
<< (mOutputLod0Function ? "Lod0(" : "(");
}
for (unsigned int i = 0; i < parameters->size(); i++)
size_t paramCount = func->getParamCount();
for (unsigned int i = 0; i < paramCount; i++)
{
TIntermSymbol *symbol = (*parameters)[i]->getAsSymbolNode();
const TVariable *param = func->getParam(i);
ensureStructDefined(param->getType());
if (symbol)
{
ensureStructDefined(symbol->getType());
writeParameter(symbol, out);
writeParameter(param, out);
if (i < parameters->size() - 1)
if (i < paramCount - 1)
{
out << ", ";
}
}
else
UNREACHABLE();
}
out << ")\n";
......@@ -1871,32 +1866,29 @@ bool OutputHLSL::visitInvariantDeclaration(Visit visit, TIntermInvariantDeclarat
return false;
}
bool OutputHLSL::visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node)
void OutputHLSL::visitFunctionPrototype(TIntermFunctionPrototype *node)
{
TInfoSinkBase &out = getInfoSink();
ASSERT(visit == PreVisit);
size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
// Skip the prototype if it is not implemented (and thus not used)
if (index == CallDAG::InvalidIndex)
{
return false;
return;
}
TIntermSequence *arguments = node->getSequence();
const TFunction *func = node->getFunction();
TString name = DecorateFunctionIfNeeded(node->getFunction());
out << TypeString(node->getType()) << " " << name << DisambiguateFunctionName(arguments)
TString name = DecorateFunctionIfNeeded(func);
out << TypeString(node->getType()) << " " << name << DisambiguateFunctionName(func)
<< (mOutputLod0Function ? "Lod0(" : "(");
for (unsigned int i = 0; i < arguments->size(); i++)
size_t paramCount = func->getParamCount();
for (unsigned int i = 0; i < paramCount; i++)
{
TIntermSymbol *symbol = (*arguments)[i]->getAsSymbolNode();
ASSERT(symbol != nullptr);
writeParameter(func->getParam(i), out);
writeParameter(symbol, out);
if (i < arguments->size() - 1)
if (i < paramCount - 1)
{
out << ", ";
}
......@@ -1912,8 +1904,6 @@ bool OutputHLSL::visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *n
node->traverse(this);
mOutputLod0Function = false;
}
return false;
}
bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
......@@ -2642,22 +2632,13 @@ void OutputHLSL::outputLineDirective(TInfoSinkBase &out, int line)
}
}
void OutputHLSL::writeParameter(const TIntermSymbol *symbol, TInfoSinkBase &out)
void OutputHLSL::writeParameter(const TVariable *param, TInfoSinkBase &out)
{
TQualifier qualifier = symbol->getQualifier();
const TType &type = symbol->getType();
const TVariable &variable = symbol->variable();
TString nameStr;
const TType &type = param->getType();
TQualifier qualifier = type.getQualifier();
if (variable.symbolType() ==
SymbolType::Empty) // HLSL demands named arguments, also for prototypes
{
nameStr = "x" + str(mUniqueIndex++);
}
else
{
nameStr = DecorateVariableIfNeeded(variable);
}
TString nameStr = DecorateVariableIfNeeded(*param);
ASSERT(nameStr != ""); // HLSL demands named arguments, also for prototypes
if (IsSampler(type.getBasicType()))
{
......
......@@ -93,7 +93,7 @@ class OutputHLSL : public TIntermTraverser
bool visitIfElse(Visit visit, TIntermIfElse *) override;
bool visitSwitch(Visit visit, TIntermSwitch *) override;
bool visitCase(Visit visit, TIntermCase *) override;
bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override;
void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *) override;
bool visitBlock(Visit visit, TIntermBlock *node) override;
......@@ -112,7 +112,7 @@ class OutputHLSL : public TIntermTraverser
const char *inString,
const char *postString);
void outputLineDirective(TInfoSinkBase &out, int line);
void writeParameter(const TIntermSymbol *symbol, TInfoSinkBase &out);
void writeParameter(const TVariable *param, TInfoSinkBase &out);
void outputConstructor(TInfoSinkBase &out, Visit visit, TIntermAggregate *node);
const TConstantUnion *writeConstantUnion(TInfoSinkBase &out,
......
......@@ -43,7 +43,7 @@ class TOutputTraverser : public TIntermTraverser
bool visitIfElse(Visit visit, TIntermIfElse *node) override;
bool visitSwitch(Visit visit, TIntermSwitch *node) override;
bool visitCase(Visit visit, TIntermCase *node) override;
bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override;
void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *) override;
bool visitBlock(Visit visit, TIntermBlock *) override;
......@@ -361,14 +361,20 @@ bool TOutputTraverser::visitInvariantDeclaration(Visit visit, TIntermInvariantDe
return true;
}
bool TOutputTraverser::visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node)
void TOutputTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
{
OutputTreeText(mOut, node, mDepth);
OutputFunction(mOut, "Function Prototype", node->getFunction());
mOut << " (" << node->getCompleteString() << ")";
mOut << "\n";
return true;
size_t paramCount = node->getFunction()->getParamCount();
for (size_t i = 0; i < paramCount; ++i)
{
const TVariable *param = node->getFunction()->getParam(i);
OutputTreeText(mOut, node, mDepth + 1);
mOut << "parameter: " << param->name() << " (" << param->getType().getCompleteString()
<< ")";
}
}
bool TOutputTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
......
......@@ -1675,7 +1675,7 @@ void TParseContext::functionCallRValueLValueErrorCheck(const TFunction *fnCandid
{
for (size_t i = 0; i < fnCandidate->getParamCount(); ++i)
{
TQualifier qual = fnCandidate->getParam(i).type->getQualifier();
TQualifier qual = fnCandidate->getParam(i)->getType().getQualifier();
TIntermTyped *argument = (*(fnCall->getSequence()))[i]->getAsTyped();
if (!IsImage(argument->getBasicType()) && (IsQualifierUnspecified(qual) || qual == EvqIn ||
qual == EvqInOut || qual == EvqConstReadOnly))
......@@ -3155,47 +3155,31 @@ TIntermFunctionPrototype *TParseContext::createPrototypeNodeFromFunction(
for (size_t i = 0; i < function.getParamCount(); i++)
{
const TConstParameter &param = function.getParam(i);
TIntermSymbol *symbol = nullptr;
const TVariable *param = function.getParam(i);
// If the parameter has no name, it's not an error, just don't add it to symbol table (could
// be used for unused args).
if (param.name != nullptr)
if (param->symbolType() != SymbolType::Empty)
{
TVariable *variable =
new TVariable(&symbolTable, param.name, param.type, SymbolType::UserDefined);
symbol = new TIntermSymbol(variable);
// Insert the parameter in the symbol table.
if (insertParametersToSymbolTable)
{
if (!symbolTable.declare(variable))
if (!symbolTable.declare(const_cast<TVariable *>(param)))
{
error(location, "redefinition", param.name);
error(location, "redefinition", param->name());
}
}
// Unsized type of a named parameter should have already been checked and sanitized.
ASSERT(!param.type->isUnsizedArray());
ASSERT(!param->getType().isUnsizedArray());
}
else
{
if (param.type->isUnsizedArray())
if (param->getType().isUnsizedArray())
{
error(location, "function parameter array must be sized at compile time", "[]");
// We don't need to size the arrays since the parameter is unnamed and hence
// inaccessible.
}
}
if (!symbol)
{
// The parameter had no name or declaring the symbol failed - either way, add a nameless
// symbol.
TVariable *emptyVariable =
new TVariable(&symbolTable, ImmutableString(""), param.type, SymbolType::Empty);
symbol = new TIntermSymbol(emptyVariable);
}
symbol->setLine(location);
prototype->appendParameter(symbol);
}
return prototype;
}
......@@ -3288,8 +3272,8 @@ TFunction *TParseContext::parseFunctionDeclarator(const TSourceLoc &location, TF
for (size_t i = 0u; i < function->getParamCount(); ++i)
{
auto &param = function->getParam(i);
if (param.type->isStructSpecifier())
const TVariable *param = function->getParam(i);
if (param->getType().isStructSpecifier())
{
// ESSL 3.00.6 section 12.10.
error(location, "Function parameter type cannot be a structure definition",
......@@ -3335,12 +3319,12 @@ TFunction *TParseContext::parseFunctionDeclarator(const TSourceLoc &location, TF
}
for (size_t i = 0; i < prevDec->getParamCount(); ++i)
{
if (prevDec->getParam(i).type->getQualifier() !=
function->getParam(i).type->getQualifier())
if (prevDec->getParam(i)->getType().getQualifier() !=
function->getParam(i)->getType().getQualifier())
{
error(location,
"function must have the same parameter qualifiers in all of its declarations",
function->getParam(i).type->getQualifierString());
function->getParam(i)->getType().getQualifierString());
}
}
}
......@@ -5691,7 +5675,7 @@ void TParseContext::checkImageMemoryAccessForUserDefinedFunctions(
{
TIntermTyped *typedArgument = arguments[i]->getAsTyped();
const TType &functionArgumentType = typedArgument->getType();
const TType &functionParameterType = *functionDefinition->getParam(i).type;
const TType &functionParameterType = functionDefinition->getParam(i)->getType();
ASSERT(functionArgumentType.getBasicType() == functionParameterType.getBasicType());
if (IsImage(functionArgumentType.getBasicType()))
......
......@@ -21,42 +21,42 @@ namespace BuiltInGroup
bool isTextureOffsetNoBias(const TFunction *func)
{
int id = func->uniqueId().get();
return id >= 605 && id <= 674;
return id >= 662 && id <= 731;
}
bool isTextureOffsetBias(const TFunction *func)
{
int id = func->uniqueId().get();
return id >= 675 && id <= 694;
return id >= 732 && id <= 751;
}
bool isTextureGatherOffset(const TFunction *func)
{
int id = func->uniqueId().get();
return id >= 764 && id <= 777;
return id >= 823 && id <= 836;
}
bool isTextureGather(const TFunction *func)
{
int id = func->uniqueId().get();
return id >= 740 && id <= 777;
return id >= 799 && id <= 836;
}
bool isAtomicMemory(const TFunction *func)
{
int id = func->uniqueId().get();
return id >= 793 && id <= 808;
return id >= 853 && id <= 870;
}
bool isImageLoad(const TFunction *func)
{
int id = func->uniqueId().get();
return id >= 821 && id <= 832;
return id >= 895 && id <= 906;
}
bool isImageStore(const TFunction *func)
{
int id = func->uniqueId().get();
return id >= 833 && id <= 844;
return id >= 907 && id <= 918;
}
bool isImage(const TFunction *func)
{
int id = func->uniqueId().get();
return id >= 809 && id <= 844;
return id >= 871 && id <= 918;
}
} // namespace BuiltInGroup
......
......@@ -65,13 +65,6 @@ std::string GetIndexFunctionName(const TType &type, bool write)
return nameSink.str();
}
TIntermSymbol *CreateParameterSymbol(const TConstParameter &parameter, TSymbolTable *symbolTable)
{
TVariable *variable =
new TVariable(symbolTable, parameter.name, parameter.type, SymbolType::AngleInternal);
return new TIntermSymbol(variable);
}
TIntermConstantUnion *CreateIntConstantNode(int i)
{
TConstantUnion *constant = new TConstantUnion();
......@@ -180,15 +173,12 @@ TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type,
std::string functionName = GetIndexFunctionName(type, write);
TIntermFunctionPrototype *prototypeNode = CreateInternalFunctionPrototypeNode(func);
TIntermSymbol *baseParam = CreateParameterSymbol(func.getParam(0), symbolTable);
prototypeNode->getSequence()->push_back(baseParam);
TIntermSymbol *indexParam = CreateParameterSymbol(func.getParam(1), symbolTable);
prototypeNode->getSequence()->push_back(indexParam);
TIntermSymbol *baseParam = new TIntermSymbol(func.getParam(0));
TIntermSymbol *indexParam = new TIntermSymbol(func.getParam(1));
TIntermSymbol *valueParam = nullptr;
if (write)
{
valueParam = CreateParameterSymbol(func.getParam(2), symbolTable);
prototypeNode->getSequence()->push_back(valueParam);
valueParam = new TIntermSymbol(func.getParam(2));
}
TIntermBlock *statementList = new TIntermBlock();
......@@ -408,9 +398,10 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod
indexingFunction =
new TFunction(mSymbolTable, indexingFunctionName, SymbolType::AngleInternal,
GetFieldType(type), true);
indexingFunction->addParameter(new TVariable(
mSymbolTable, kBaseName, GetBaseType(type, false), SymbolType::AngleInternal));
indexingFunction->addParameter(
TConstParameter(kBaseName, GetBaseType(type, false)));
indexingFunction->addParameter(TConstParameter(kIndexName, kIndexType));
new TVariable(mSymbolTable, kIndexName, kIndexType, SymbolType::AngleInternal));
mIndexedVecAndMatrixTypes[type] = indexingFunction;
}
else
......@@ -458,13 +449,16 @@ bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *nod
indexedWriteFunction =
new TFunction(mSymbolTable, functionName, SymbolType::AngleInternal,
StaticType::GetBasic<EbtVoid>(), false);
indexedWriteFunction->addParameter(
TConstParameter(kBaseName, GetBaseType(type, true)));
indexedWriteFunction->addParameter(TConstParameter(kIndexName, kIndexType));
indexedWriteFunction->addParameter(new TVariable(mSymbolTable, kBaseName,
GetBaseType(type, true),
SymbolType::AngleInternal));
indexedWriteFunction->addParameter(new TVariable(
mSymbolTable, kIndexName, kIndexType, SymbolType::AngleInternal));
TType *valueType = GetFieldType(type);
valueType->setQualifier(EvqIn);
indexedWriteFunction->addParameter(
TConstParameter(kValueName, static_cast<const TType *>(valueType)));
indexedWriteFunction->addParameter(new TVariable(
mSymbolTable, kValueName, static_cast<const TType *>(valueType),
SymbolType::AngleInternal));
mWrittenVecAndMatrixTypes[type] = indexedWriteFunction;
}
else
......
......@@ -30,7 +30,7 @@ class CollectVariableRefCountsTraverser : public TIntermTraverser
void visitSymbol(TIntermSymbol *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *node) override;
bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override;
void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
private:
void incrementStructTypeRefCount(const TType &type);
......@@ -108,11 +108,14 @@ bool CollectVariableRefCountsTraverser::visitAggregate(Visit visit, TIntermAggre
return true;
}
bool CollectVariableRefCountsTraverser::visitFunctionPrototype(Visit visit,
TIntermFunctionPrototype *node)
void CollectVariableRefCountsTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
{
incrementStructTypeRefCount(node->getType());
return true;
size_t paramCount = node->getFunction()->getParamCount();
for (size_t i = 0; i < paramCount; ++i)
{
incrementStructTypeRefCount(node->getFunction()->getParam(i)->getType());
}
}
// Traverser that removes all unreferenced variables on one traversal.
......
......@@ -164,7 +164,7 @@ TFunction::TFunction(TSymbolTable *symbolTable,
ASSERT(name != nullptr || symbolType == SymbolType::AngleInternal);
}
void TFunction::addParameter(const TConstParameter &p)
void TFunction::addParameter(const TVariable *p)
{
ASSERT(mParametersVector);
mParametersVector->push_back(p);
......@@ -189,7 +189,7 @@ ImmutableString TFunction::buildMangledName() const
for (size_t i = 0u; i < mParamCount; ++i)
{
newName += mParameters[i].type->getMangledName();
newName += mParameters[i]->getType().getMangledName();
}
return ImmutableString(newName);
}
......
......@@ -171,38 +171,20 @@ class TInterfaceBlock : public TSymbol, public TFieldListCollection
// Note that we only record matrix packing on a per-field granularity.
};
// Immutable version of TParameter.
struct TConstParameter
{
POOL_ALLOCATOR_NEW_DELETE();
TConstParameter() : name(""), type(nullptr) {}
explicit TConstParameter(const ImmutableString &n) : name(n), type(nullptr) {}
constexpr explicit TConstParameter(const TType *t) : name(""), type(t) {}
TConstParameter(const ImmutableString &n, const TType *t) : name(n), type(t) {}
// Both constructor arguments must be const.
TConstParameter(ImmutableString *n, TType *t) = delete;
TConstParameter(const ImmutableString *n, TType *t) = delete;
TConstParameter(ImmutableString *n, const TType *t) = delete;
const ImmutableString name;
const TType *const type;
};
// The function sub-class of symbols and the parser will need to
// share this definition of a function parameter.
// Parameter class used for parsing user-defined function parameters.
struct TParameter
{
// Destructively converts to TConstParameter.
// Destructively converts to TVariable.
// This method resets name and type to nullptrs to make sure
// their content cannot be modified after the call.
TConstParameter turnToConst()
const TVariable *createVariable(TSymbolTable *symbolTable)
{
const ImmutableString constName(name);
const TType *constType = type;
name = nullptr;
type = nullptr;
return TConstParameter(constName, constType);
return new TVariable(symbolTable, constName, constType,
constName.empty() ? SymbolType::Empty : SymbolType::UserDefined);
}
const char *name; // either pool allocated or static.
......@@ -222,7 +204,7 @@ class TFunction : public TSymbol
bool isFunction() const override { return true; }
void addParameter(const TConstParameter &p);
void addParameter(const TVariable *p);
void shareParameters(const TFunction &parametersSource);
ImmutableString getMangledName() const override
......@@ -244,7 +226,7 @@ class TFunction : public TSymbol
bool hasPrototypeDeclaration() const { return mHasPrototypeDeclaration; }
size_t getParamCount() const { return mParamCount; }
const TConstParameter &getParam(size_t i) const { return mParameters[i]; }
const TVariable *getParam(size_t i) const { return mParameters[i]; }
bool isKnownToNotHaveSideEffects() const { return mKnownToNotHaveSideEffects; }
......@@ -255,7 +237,7 @@ class TFunction : public TSymbol
constexpr TFunction(const TSymbolUniqueId &id,
const ImmutableString &name,
TExtension extension,
const TConstParameter *parameters,
const TVariable *const *parameters,
size_t paramCount,
const TType *retType,
const ImmutableString &mangledName,
......@@ -277,9 +259,9 @@ class TFunction : public TSymbol
private:
ImmutableString buildMangledName() const;
typedef TVector<TConstParameter> TParamVector;
typedef TVector<const TVariable *> TParamVector;
TParamVector *mParametersVector;
const TConstParameter *mParameters;
const TVariable *const *mParameters;
size_t mParamCount;
const TType *const returnType;
mutable ImmutableString mMangledName;
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -15,6 +15,33 @@
namespace sh
{
namespace
{
void DisambiguateFunctionNameForParameterType(const TType &paramType,
TString *disambiguatingStringOut)
{
// Parameter types are only added to function names if they are ambiguous according to the
// native HLSL compiler. Other parameter types are not added to function names to avoid
// making function names longer.
if (paramType.getObjectSize() == 4 && paramType.getBasicType() == EbtFloat)
{
// Disambiguation is needed for float2x2 and float4 parameters. These are the only
// built-in types that HLSL thinks are identical. float2x3 and float3x2 are different
// types, for example.
*disambiguatingStringOut += "_" + TypeString(paramType);
}
else if (paramType.getBasicType() == EbtStruct)
{
// Disambiguation is needed for struct parameters, since HLSL thinks that structs with
// the same fields but a different name are identical.
ASSERT(paramType.getStruct()->symbolType() != SymbolType::Empty);
*disambiguatingStringOut += "_" + TypeString(paramType);
}
}
} // anonymous namespace
const char *SamplerString(const TBasicType type)
{
if (IsShadowSampler(type))
......@@ -1025,29 +1052,26 @@ const char *QualifierString(TQualifier qualifier)
return "";
}
TString DisambiguateFunctionName(const TIntermSequence *parameters)
TString DisambiguateFunctionName(const TFunction *func)
{
TString disambiguatingString;
for (auto parameter : *parameters)
{
const TType &paramType = parameter->getAsTyped()->getType();
// Parameter types are only added to function names if they are ambiguous according to the
// native HLSL compiler. Other parameter types are not added to function names to avoid
// making function names longer.
if (paramType.getObjectSize() == 4 && paramType.getBasicType() == EbtFloat)
size_t paramCount = func->getParamCount();
for (size_t i = 0; i < paramCount; ++i)
{
// Disambiguation is needed for float2x2 and float4 parameters. These are the only
// built-in types that HLSL thinks are identical. float2x3 and float3x2 are different
// types, for example.
disambiguatingString += "_" + TypeString(paramType);
DisambiguateFunctionNameForParameterType(func->getParam(i)->getType(),
&disambiguatingString);
}
else if (paramType.getBasicType() == EbtStruct)
return disambiguatingString;
}
TString DisambiguateFunctionName(const TIntermSequence *args)
{
TString disambiguatingString;
for (TIntermNode *arg : *args)
{
// Disambiguation is needed for struct parameters, since HLSL thinks that structs with
// the same fields but a different name are identical.
ASSERT(paramType.getStruct()->symbolType() != SymbolType::Empty);
disambiguatingString += "_" + TypeString(paramType);
}
ASSERT(arg->getAsTyped());
DisambiguateFunctionNameForParameterType(arg->getAsTyped()->getType(),
&disambiguatingString);
}
return disambiguatingString;
}
......
......@@ -19,6 +19,8 @@
namespace sh
{
class TFunction;
// HLSL Texture type for GLSL sampler type and readonly image type.
enum HLSLTextureGroup
{
......@@ -123,7 +125,8 @@ const char *InterpolationString(TQualifier qualifier);
const char *QualifierString(TQualifier qualifier);
// Parameters may need to be included in function names to disambiguate between overloaded
// functions.
TString DisambiguateFunctionName(const TIntermSequence *parameters);
TString DisambiguateFunctionName(const TFunction *func);
TString DisambiguateFunctionName(const TIntermSequence *args);
}
#endif // COMPILER_TRANSLATOR_UTILSHLSL_H_
......@@ -8,6 +8,7 @@
#include "compiler/translator/ValidateMaxParameters.h"
#include "compiler/translator/IntermNode.h"
#include "compiler/translator/Symbol.h"
namespace sh
{
......@@ -18,7 +19,7 @@ bool ValidateMaxParameters(TIntermBlock *root, unsigned int maxParameters)
{
TIntermFunctionDefinition *definition = node->getAsFunctionDefinition();
if (definition != nullptr &&
definition->getFunctionPrototype()->getSequence()->size() > maxParameters)
definition->getFunctionPrototype()->getFunction()->getParamCount() > maxParameters)
{
return false;
}
......
......@@ -105,15 +105,16 @@ bool TVersionGLSL::visitInvariantDeclaration(Visit, TIntermInvariantDeclaration
return true;
}
bool TVersionGLSL::visitFunctionPrototype(Visit, TIntermFunctionPrototype *node)
void TVersionGLSL::visitFunctionPrototype(TIntermFunctionPrototype *node)
{
const TIntermSequence &params = *(node->getSequence());
for (TIntermSequence::const_iterator iter = params.begin(); iter != params.end(); ++iter)
size_t paramCount = node->getFunction()->getParamCount();
for (size_t i = 0; i < paramCount; ++i)
{
const TIntermTyped *param = (*iter)->getAsTyped();
if (param->isArray())
const TVariable *param = node->getFunction()->getParam(i);
const TType &type = param->getType();
if (type.isArray())
{
TQualifier qualifier = param->getQualifier();
TQualifier qualifier = type.getQualifier();
if ((qualifier == EvqOut) || (qualifier == EvqInOut))
{
ensureVersionIsAtLeast(GLSL_VERSION_120);
......@@ -121,8 +122,6 @@ bool TVersionGLSL::visitFunctionPrototype(Visit, TIntermFunctionPrototype *node)
}
}
}
// Fully processed. No need to visit children.
return false;
}
bool TVersionGLSL::visitAggregate(Visit, TIntermAggregate *node)
......
......@@ -62,7 +62,7 @@ class TVersionGLSL : public TIntermTraverser
void visitSymbol(TIntermSymbol *node) override;
bool visitAggregate(Visit, TIntermAggregate *node) override;
bool visitInvariantDeclaration(Visit, TIntermInvariantDeclaration *node) override;
bool visitFunctionPrototype(Visit, TIntermFunctionPrototype *node) override;
void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
bool visitDeclaration(Visit, TIntermDeclaration *node) override;
private:
......
......@@ -121,10 +121,23 @@ const int TSymbolTable::kLastBuiltInId = {last_builtin_id};
namespace BuiltInName
{{
constexpr const ImmutableString _empty("");
{name_declarations}
}} // namespace BuiltInName
// TODO(oetuaho): Would be nice to make this a class instead of a namespace so that we could friend
// this from TVariable. Now symbol constructors taking an id have to be public even though they're
// not supposed to be accessible from outside of here. http://anglebug.com/2390
namespace BuiltInVariable
{{
{variable_declarations}
{get_variable_definitions}
}}; // namespace BuiltInVariable
namespace BuiltInParameters
{{
......@@ -140,18 +153,6 @@ namespace UnmangledBuiltIns
}} // namespace UnmangledBuiltIns
// TODO(oetuaho): Would be nice to make this a class instead of a namespace so that we could friend
// this from TVariable. Now symbol constructors taking an id have to be public even though they're
// not supposed to be accessible from outside of here. http://anglebug.com/2390
namespace BuiltInVariable
{{
{variable_declarations}
{get_variable_definitions}
}}; // namespace BuiltInVariable
// TODO(oetuaho): Would be nice to make this a class instead of a namespace so that we could friend
// this from TFunction. Now symbol constructors taking an id have to be public even though they're
// not supposed to be accessible from outside of here. http://anglebug.com/2390
namespace BuiltInFunction
......@@ -630,6 +631,16 @@ def get_unique_identifier_name(function_name, parameters):
unique_name += param.get_mangled_name()
return unique_name
def get_variable_name_to_store_parameter(param):
unique_name = 'pt'
if 'qualifier' in param.data:
if param.data['qualifier'] == 'Out':
unique_name += '_o_'
if param.data['qualifier'] == 'InOut':
unique_name += '_io_'
unique_name += param.get_mangled_name()
return unique_name
def get_variable_name_to_store_parameters(parameters):
if len(parameters) == 0:
return 'empty'
......@@ -643,6 +654,10 @@ def get_variable_name_to_store_parameters(parameters):
unique_name += param.get_mangled_name()
return unique_name
def define_constexpr_variable(template_args):
template_variable_declaration = 'constexpr const TVariable kVar_{name_with_suffix}(BuiltInId::{name_with_suffix}, BuiltInName::{name}, SymbolType::BuiltIn, TExtension::{extension}, {type});'
variable_declarations.append(template_variable_declaration.format(**template_args))
def gen_function_variants(function_name, function_props):
function_variants = []
parameters = get_parameters(function_props)
......@@ -688,6 +703,7 @@ def gen_function_variants(function_name, function_props):
return function_variants
defined_function_variants = set()
defined_parameter_names = set()
def process_single_function_group(condition, group_name, group):
global id_counter
......@@ -751,15 +767,29 @@ def process_single_function_group(condition, group_name, group):
parameters_list = []
for param in parameters:
template_parameter = 'TConstParameter({param_type})'
parameters_list.append(template_parameter.format(param_type = param.get_statictype_string()))
unique_param_name = get_variable_name_to_store_parameter(param)
param_template_args = {
'name': '_empty',
'name_with_suffix': unique_param_name,
'type': param.get_statictype_string(),
'extension': 'UNDEFINED'
}
if unique_param_name not in defined_parameter_names:
id_counter += 1
param_template_args['id'] = id_counter
template_builtin_id_declaration = ' static constexpr const TSymbolUniqueId {name_with_suffix} = TSymbolUniqueId({id});'
builtin_id_declarations.append(template_builtin_id_declaration.format(**param_template_args))
define_constexpr_variable(param_template_args)
defined_parameter_names.add(unique_param_name)
parameters_list.append('&BuiltInVariable::kVar_{name_with_suffix}'.format(**param_template_args));
template_args['parameters_var_name'] = get_variable_name_to_store_parameters(parameters)
if len(parameters) > 0:
template_args['parameters_list'] = ', '.join(parameters_list)
template_parameter_list_declaration = 'constexpr const TConstParameter {parameters_var_name}[{param_count}] = {{ {parameters_list} }};'
template_parameter_list_declaration = 'constexpr const TVariable *{parameters_var_name}[{param_count}] = {{ {parameters_list} }};'
parameter_declarations.add(template_parameter_list_declaration.format(**template_args))
else:
template_parameter_list_declaration = 'constexpr const TConstParameter *{parameters_var_name} = nullptr;'
template_parameter_list_declaration = 'constexpr const TVariable **{parameters_var_name} = nullptr;'
parameter_declarations.add(template_parameter_list_declaration.format(**template_args))
template_function_declaration = 'constexpr const TFunction kFunction_{unique_name}(BuiltInId::{unique_name}, BuiltInName::{name_with_suffix}, TExtension::{extension}, BuiltInParameters::{parameters_var_name}, {param_count}, {return_type}, BuiltInName::{unique_name}, EOp{op}, {known_to_not_have_side_effects});'
......@@ -891,8 +921,7 @@ def process_single_variable_group(group_name, group):
else:
# Handle variables that can be stored as constexpr TVariable like
# gl_Position, gl_FragColor etc.
template_variable_declaration = 'constexpr const TVariable kVar_{name_with_suffix}(BuiltInId::{name_with_suffix}, BuiltInName::{name}, SymbolType::BuiltIn, TExtension::{extension}, {type});'
variable_declarations.append(template_variable_declaration.format(**template_args))
define_constexpr_variable(template_args)
template_get_variable_declaration = 'const TVariable *{name_with_suffix}();'
get_variable_declarations.append(template_get_variable_declaration.format(**template_args))
......
......@@ -650,7 +650,7 @@ function_header_with_parameters
$$ = $1;
if ($2.type->getBasicType() != EbtVoid)
{
$1->addParameter($2.turnToConst());
$1->addParameter($2.createVariable(&context->symbolTable));
}
}
| function_header_with_parameters COMMA parameter_declaration {
......@@ -664,7 +664,7 @@ function_header_with_parameters
}
else
{
$1->addParameter($3.turnToConst());
$1->addParameter($3.createVariable(&context->symbolTable));
}
}
;
......
......@@ -3268,7 +3268,7 @@ yyreduce:
(yyval.interm.function) = (yyvsp[-1].interm.function);
if ((yyvsp[0].interm.param).type->getBasicType() != EbtVoid)
{
(yyvsp[-1].interm.function)->addParameter((yyvsp[0].interm.param).turnToConst());
(yyvsp[-1].interm.function)->addParameter((yyvsp[0].interm.param).createVariable(&context->symbolTable));
}
}
......@@ -3287,7 +3287,7 @@ yyreduce:
}
else
{
(yyvsp[-2].interm.function)->addParameter((yyvsp[0].interm.param).turnToConst());
(yyvsp[-2].interm.function)->addParameter((yyvsp[0].interm.param).createVariable(&context->symbolTable));
}
}
......
......@@ -70,7 +70,8 @@ class IntermNodeTest : public testing::Test
for (TIntermNode *arg : args)
{
const TType *type = new TType(arg->getAsTyped()->getType());
func->addParameter(TConstParameter(type));
func->addParameter(new TVariable(&symbolTable, ImmutableString("param"), type,
SymbolType::UserDefined));
}
return func;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment