Commit ea22b7a5 by Olli Etuaho Committed by Commit Bot

Constant fold array indexing and comparison

A virtual function to get the constant value of an AST node is added to TIntermTyped. This way a constant value can be retrieved conveniently from multiple different types of nodes. TIntermSymbol nodes pointing to a const variable can return the value associated with the variable, constructor nodes can build a constant value from their arguments, and indexing nodes can index into a constant array. This enables constant folding operations on constant arrays, while making sure that large amounts of data are not duplicated in the output shader. When folding an operation makes sense, the values of the arguments can be retrieved by using the new TIntermTyped::getConstantValue(). When folding an operation would result in duplicating data, the AST can just be left to be written out as is. For example, if the code contains a constant array of arrays, indexing into individual elements of the inner arrays can be folded, but indexing the top level array is left in place and not replaced with duplicated array literals. Constant folding is supported for indexing and comparisons of arrays. In case constant arrays are only referenced through foldable operations, the variable declarations will be pruned from the AST by the RemoveUnreferencedVariables step. BUG=angleproject:2298 TEST=angle_unittests Change-Id: I5b3be237b7e9fdba56aa9bf0a41b691f4d8f01eb Reviewed-on: https://chromium-review.googlesource.com/850973Reviewed-by: 's avatarGeoff Lang <geofflang@chromium.org> Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
parent ed1390f8
...@@ -45,14 +45,12 @@ void GetDeferredInitializers(TIntermDeclaration *declaration, ...@@ -45,14 +45,12 @@ void GetDeferredInitializers(TIntermDeclaration *declaration,
ASSERT(symbolNode); ASSERT(symbolNode);
TIntermTyped *expression = init->getRight(); TIntermTyped *expression = init->getRight();
if ((expression->getQualifier() != EvqConst || if (expression->getQualifier() != EvqConst || !expression->hasConstantValue())
(expression->getAsConstantUnion() == nullptr &&
!expression->isConstructorWithOnlyConstantUnionParameters())))
{ {
// For variables which are not constant, defer their real initialization until // For variables which are not constant, defer their real initialization until
// after we initialize uniforms. // after we initialize uniforms.
// Deferral is done also in any cases where the variable has not been constant // Deferral is done also in any cases where the variable can not be converted to a
// folded, since otherwise there's a chance that HLSL output will generate extra // constant union, since otherwise there's a chance that HLSL output will generate extra
// statements from the initializer expression. // statements from the initializer expression.
// Change const global to a regular global if its initialization is deferred. // Change const global to a regular global if its initialization is deferred.
......
...@@ -81,7 +81,7 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -81,7 +81,7 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
} }
ASSERT(constantExponent->getBasicType() == EbtFloat); ASSERT(constantExponent->getBasicType() == EbtFloat);
float exponentValue = constantExponent->getUnionArrayPointer()->getFConst(); float exponentValue = constantExponent->getConstantValue()->getFConst();
// Test 2: exponentValue is in the problematic range. // Test 2: exponentValue is in the problematic range.
if (exponentValue < -5.0f || exponentValue > 9.0f) if (exponentValue < -5.0f || exponentValue > 9.0f)
......
...@@ -131,6 +131,12 @@ class TIntermTyped : public TIntermNode ...@@ -131,6 +131,12 @@ class TIntermTyped : public TIntermNode
virtual TIntermTyped *fold(TDiagnostics *diagnostics) { return this; } virtual TIntermTyped *fold(TDiagnostics *diagnostics) { return this; }
// getConstantValue() returns the constant value that this node represents, if any. It
// should only be used after nodes have been replaced with their folded versions returned
// from fold(). hasConstantValue() returns true if getConstantValue() will return a value.
virtual bool hasConstantValue() const;
virtual const TConstantUnion *getConstantValue() const;
// True if executing the expression represented by this node affects state, like values of // True if executing the expression represented by this node affects state, like values of
// variables. False if the executing the expression only computes its return value without // variables. False if the executing the expression only computes its return value without
// affecting state. May return true conservatively. // affecting state. May return true conservatively.
...@@ -161,8 +167,6 @@ class TIntermTyped : public TIntermNode ...@@ -161,8 +167,6 @@ class TIntermTyped : public TIntermNode
unsigned int getOutermostArraySize() const { return mType.getOutermostArraySize(); } unsigned int getOutermostArraySize() const { return mType.getOutermostArraySize(); }
bool isConstructorWithOnlyConstantUnionParameters();
protected: protected:
TType mType; TType mType;
...@@ -241,6 +245,9 @@ class TIntermSymbol : public TIntermTyped ...@@ -241,6 +245,9 @@ class TIntermSymbol : public TIntermTyped
TIntermTyped *deepCopy() const override { return new TIntermSymbol(*this); } TIntermTyped *deepCopy() const override { return new TIntermSymbol(*this); }
bool hasConstantValue() const override;
const TConstantUnion *getConstantValue() const override;
bool hasSideEffects() const override { return false; } bool hasSideEffects() const override { return false; }
const TSymbolUniqueId &uniqueId() const; const TSymbolUniqueId &uniqueId() const;
...@@ -302,9 +309,10 @@ class TIntermConstantUnion : public TIntermTyped ...@@ -302,9 +309,10 @@ class TIntermConstantUnion : public TIntermTyped
TIntermTyped *deepCopy() const override { return new TIntermConstantUnion(*this); } TIntermTyped *deepCopy() const override { return new TIntermConstantUnion(*this); }
bool hasSideEffects() const override { return false; } bool hasConstantValue() const override;
const TConstantUnion *getConstantValue() const override;
const TConstantUnion *getUnionArrayPointer() const { return mUnionArrayPointer; } bool hasSideEffects() const override { return false; }
int getIConst(size_t index) const int getIConst(size_t index) const
{ {
...@@ -334,15 +342,20 @@ class TIntermConstantUnion : public TIntermTyped ...@@ -334,15 +342,20 @@ class TIntermConstantUnion : public TIntermTyped
void traverse(TIntermTraverser *it) override; void traverse(TIntermTraverser *it) override;
bool replaceChildNode(TIntermNode *, TIntermNode *) override { return false; } bool replaceChildNode(TIntermNode *, TIntermNode *) override { return false; }
TConstantUnion *foldBinary(TOperator op,
TIntermConstantUnion *rightNode,
TDiagnostics *diagnostics,
const TSourceLoc &line);
const TConstantUnion *foldIndexing(int index);
TConstantUnion *foldUnaryNonComponentWise(TOperator op); TConstantUnion *foldUnaryNonComponentWise(TOperator op);
TConstantUnion *foldUnaryComponentWise(TOperator op, TDiagnostics *diagnostics); TConstantUnion *foldUnaryComponentWise(TOperator op, TDiagnostics *diagnostics);
static TConstantUnion *FoldAggregateConstructor(TIntermAggregate *aggregate); static const TConstantUnion *FoldBinary(TOperator op,
const TConstantUnion *leftArray,
const TType &leftType,
const TConstantUnion *rightArray,
const TType &rightType,
TDiagnostics *diagnostics,
const TSourceLoc &line);
static const TConstantUnion *FoldIndexing(const TType &type,
const TConstantUnion *constArray,
int index);
static TConstantUnion *FoldAggregateBuiltIn(TIntermAggregate *aggregate, static TConstantUnion *FoldAggregateBuiltIn(TIntermAggregate *aggregate,
TDiagnostics *diagnostics); TDiagnostics *diagnostics);
...@@ -430,6 +443,9 @@ class TIntermBinary : public TIntermOperator ...@@ -430,6 +443,9 @@ class TIntermBinary : public TIntermOperator
TIntermTyped *deepCopy() const override { return new TIntermBinary(*this); } TIntermTyped *deepCopy() const override { return new TIntermBinary(*this); }
bool hasConstantValue() const override;
const TConstantUnion *getConstantValue() const override;
static TOperator GetMulOpBasedOnOperands(const TType &left, const TType &right); static TOperator GetMulOpBasedOnOperands(const TType &left, const TType &right);
static TOperator GetMulAssignOpBasedOnOperands(const TType &left, const TType &right); static TOperator GetMulAssignOpBasedOnOperands(const TType &left, const TType &right);
static TQualifier GetCommaQualifier(int shaderVersion, static TQualifier GetCommaQualifier(int shaderVersion,
...@@ -553,6 +569,9 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase ...@@ -553,6 +569,9 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase
TIntermAggregate *shallowCopy() const; TIntermAggregate *shallowCopy() const;
bool hasConstantValue() const override;
const TConstantUnion *getConstantValue() const override;
TIntermAggregate *getAsAggregate() override { return this; } TIntermAggregate *getAsAggregate() override { return this; }
void traverse(TIntermTraverser *it) override; void traverse(TIntermTraverser *it) override;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override; bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
......
...@@ -441,7 +441,7 @@ void TOutputGLSLBase::visitSymbol(TIntermSymbol *node) ...@@ -441,7 +441,7 @@ void TOutputGLSLBase::visitSymbol(TIntermSymbol *node)
void TOutputGLSLBase::visitConstantUnion(TIntermConstantUnion *node) void TOutputGLSLBase::visitConstantUnion(TIntermConstantUnion *node)
{ {
writeConstantUnion(node->getType(), node->getUnionArrayPointer()); writeConstantUnion(node->getType(), node->getConstantValue());
} }
bool TOutputGLSLBase::visitSwizzle(Visit visit, TIntermSwizzle *node) bool TOutputGLSLBase::visitSwizzle(Visit visit, TIntermSwizzle *node)
......
...@@ -419,7 +419,8 @@ void OutputHLSL::header(TInfoSinkBase &out, ...@@ -419,7 +419,8 @@ void OutputHLSL::header(TInfoSinkBase &out,
// Program linking depends on this exact format // Program linking depends on this exact format
varyings += "static " + InterpolationString(type.getQualifier()) + " " + TypeString(type) + varyings += "static " + InterpolationString(type.getQualifier()) + " " + TypeString(type) +
" " + Decorate(name) + ArrayString(type) + " = " + initializer(type) + ";\n"; " " + Decorate(name) + ArrayString(type) + " = " + zeroInitializer(type) +
";\n";
} }
for (const auto &attribute : mReferencedAttributes) for (const auto &attribute : mReferencedAttributes)
...@@ -428,7 +429,7 @@ void OutputHLSL::header(TInfoSinkBase &out, ...@@ -428,7 +429,7 @@ void OutputHLSL::header(TInfoSinkBase &out,
const TString &name = attribute.second->name(); const TString &name = attribute.second->name();
attributes += "static " + TypeString(type) + " " + Decorate(name) + ArrayString(type) + attributes += "static " + TypeString(type) + " " + Decorate(name) + ArrayString(type) +
" = " + initializer(type) + ";\n"; " = " + zeroInitializer(type) + ";\n";
} }
out << mStructureHLSL->structsHeader(); out << mStructureHLSL->structsHeader();
...@@ -501,7 +502,8 @@ void OutputHLSL::header(TInfoSinkBase &out, ...@@ -501,7 +502,8 @@ void OutputHLSL::header(TInfoSinkBase &out,
const TType &variableType = outputVariable.second->getType(); const TType &variableType = outputVariable.second->getType();
out << "static " + TypeString(variableType) + " out_" + variableName + out << "static " + TypeString(variableType) + " out_" + variableName +
ArrayString(variableType) + " = " + initializer(variableType) + ";\n"; ArrayString(variableType) + " = " + zeroInitializer(variableType) +
";\n";
} }
} }
else else
...@@ -1156,23 +1158,22 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node) ...@@ -1156,23 +1158,22 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
{ {
TIntermSymbol *symbolNode = node->getLeft()->getAsSymbolNode(); TIntermSymbol *symbolNode = node->getLeft()->getAsSymbolNode();
ASSERT(symbolNode); ASSERT(symbolNode);
TIntermTyped *expression = node->getRight(); TIntermTyped *initializer = node->getRight();
// Global initializers must be constant at this point. // Global initializers must be constant at this point.
ASSERT(symbolNode->getQualifier() != EvqGlobal || ASSERT(symbolNode->getQualifier() != EvqGlobal || initializer->hasConstantValue());
canWriteAsHLSLLiteral(expression));
// GLSL allows to write things like "float x = x;" where a new variable x is defined // GLSL allows to write things like "float x = x;" where a new variable x is defined
// and the value of an existing variable x is assigned. HLSL uses C semantics (the // and the value of an existing variable x is assigned. HLSL uses C semantics (the
// new variable is created before the assignment is evaluated), so we need to // new variable is created before the assignment is evaluated), so we need to
// convert // convert
// this to "float t = x, x = t;". // this to "float t = x, x = t;".
if (writeSameSymbolInitializer(out, symbolNode, expression)) if (writeSameSymbolInitializer(out, symbolNode, initializer))
{ {
// Skip initializing the rest of the expression // Skip initializing the rest of the expression
return false; return false;
} }
else if (writeConstantInitialization(out, symbolNode, expression)) else if (writeConstantInitialization(out, symbolNode, initializer))
{ {
return false; return false;
} }
...@@ -1838,7 +1839,7 @@ bool OutputHLSL::visitDeclaration(Visit visit, TIntermDeclaration *node) ...@@ -1838,7 +1839,7 @@ bool OutputHLSL::visitDeclaration(Visit visit, TIntermDeclaration *node)
{ {
symbol->traverse(this); symbol->traverse(this);
out << ArrayString(symbol->getType()); out << ArrayString(symbol->getType());
out << " = " + initializer(symbol->getType()); out << " = " + zeroInitializer(symbol->getType());
} }
else else
{ {
...@@ -2240,7 +2241,7 @@ bool OutputHLSL::visitCase(Visit visit, TIntermCase *node) ...@@ -2240,7 +2241,7 @@ bool OutputHLSL::visitCase(Visit visit, TIntermCase *node)
void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node) void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
{ {
TInfoSinkBase &out = getInfoSink(); TInfoSinkBase &out = getInfoSink();
writeConstantUnion(out, node->getType(), node->getUnionArrayPointer()); writeConstantUnion(out, node->getType(), node->getConstantValue());
} }
bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node) bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node)
...@@ -2712,7 +2713,7 @@ TString OutputHLSL::argumentString(const TIntermSymbol *symbol) ...@@ -2712,7 +2713,7 @@ TString OutputHLSL::argumentString(const TIntermSymbol *symbol)
return argString.str(); return argString.str();
} }
TString OutputHLSL::initializer(const TType &type) TString OutputHLSL::zeroInitializer(const TType &type)
{ {
TString string; TString string;
...@@ -2763,6 +2764,8 @@ const TConstantUnion *OutputHLSL::writeConstantUnion(TInfoSinkBase &out, ...@@ -2763,6 +2764,8 @@ const TConstantUnion *OutputHLSL::writeConstantUnion(TInfoSinkBase &out,
const TType &type, const TType &type,
const TConstantUnion *const constUnion) const TConstantUnion *const constUnion)
{ {
ASSERT(!type.isArray());
const TConstantUnion *constUnionIterated = constUnion; const TConstantUnion *constUnionIterated = constUnion;
const TStructure *structure = type.getStruct(); const TStructure *structure = type.getStruct();
...@@ -2841,50 +2844,17 @@ bool OutputHLSL::writeSameSymbolInitializer(TInfoSinkBase &out, ...@@ -2841,50 +2844,17 @@ bool OutputHLSL::writeSameSymbolInitializer(TInfoSinkBase &out,
return false; return false;
} }
bool OutputHLSL::canWriteAsHLSLLiteral(TIntermTyped *expression)
{
// We support writing constant unions and constructors that only take constant unions as
// parameters as HLSL literals.
return !expression->getType().isArrayOfArrays() &&
(expression->getAsConstantUnion() ||
expression->isConstructorWithOnlyConstantUnionParameters());
}
bool OutputHLSL::writeConstantInitialization(TInfoSinkBase &out, bool OutputHLSL::writeConstantInitialization(TInfoSinkBase &out,
TIntermSymbol *symbolNode, TIntermSymbol *symbolNode,
TIntermTyped *initializer) TIntermTyped *initializer)
{ {
if (canWriteAsHLSLLiteral(initializer)) if (initializer->hasConstantValue())
{ {
symbolNode->traverse(this); symbolNode->traverse(this);
ASSERT(!symbolNode->getType().isArrayOfArrays()); out << ArrayString(symbolNode->getType());
if (symbolNode->getType().isArray())
{
out << "[" << symbolNode->getType().getOutermostArraySize() << "]";
}
out << " = {"; out << " = {";
if (initializer->getAsConstantUnion()) writeConstantUnionArray(out, initializer->getConstantValue(),
{ initializer->getType().getObjectSize());
TIntermConstantUnion *nodeConst = initializer->getAsConstantUnion();
const TConstantUnion *constUnion = nodeConst->getUnionArrayPointer();
writeConstantUnionArray(out, constUnion, nodeConst->getType().getObjectSize());
}
else
{
TIntermAggregate *constructor = initializer->getAsAggregate();
ASSERT(constructor != nullptr);
for (TIntermNode *&node : *constructor->getSequence())
{
TIntermConstantUnion *nodeConst = node->getAsConstantUnion();
ASSERT(nodeConst);
const TConstantUnion *constUnion = nodeConst->getUnionArrayPointer();
writeConstantUnionArray(out, constUnion, nodeConst->getType().getObjectSize());
if (node != constructor->getSequence()->back())
{
out << ", ";
}
}
}
out << "}"; out << "}";
return true; return true;
} }
......
...@@ -62,7 +62,7 @@ class OutputHLSL : public TIntermTraverser ...@@ -62,7 +62,7 @@ class OutputHLSL : public TIntermTraverser
const std::map<std::string, unsigned int> &getUniformBlockRegisterMap() const; const std::map<std::string, unsigned int> &getUniformBlockRegisterMap() const;
const std::map<std::string, unsigned int> &getUniformRegisterMap() const; const std::map<std::string, unsigned int> &getUniformRegisterMap() const;
static TString initializer(const TType &type); static TString zeroInitializer(const TType &type);
TInfoSinkBase &getInfoSink() TInfoSinkBase &getInfoSink()
{ {
...@@ -70,8 +70,6 @@ class OutputHLSL : public TIntermTraverser ...@@ -70,8 +70,6 @@ class OutputHLSL : public TIntermTraverser
return *mInfoSinkStack.top(); return *mInfoSinkStack.top();
} }
static bool canWriteAsHLSLLiteral(TIntermTyped *expression);
protected: protected:
void header(TInfoSinkBase &out, void header(TInfoSinkBase &out,
const std::vector<MappedStruct> &std140Structs, const std::vector<MappedStruct> &std140Structs,
......
...@@ -273,7 +273,7 @@ bool TOutputTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -273,7 +273,7 @@ bool TOutputTraverser::visitBinary(Visit visit, TIntermBinary *node)
OutputTreeText(mOut, intermConstantUnion, mDepth + 1); OutputTreeText(mOut, intermConstantUnion, mDepth + 1);
// The following code finds the field name from the constant union // The following code finds the field name from the constant union
const TConstantUnion *constantUnion = intermConstantUnion->getUnionArrayPointer(); const TConstantUnion *constantUnion = intermConstantUnion->getConstantValue();
const TStructure *structure = node->getLeft()->getType().getStruct(); const TStructure *structure = node->getLeft()->getType().getStruct();
const TInterfaceBlock *interfaceBlock = node->getLeft()->getType().getInterfaceBlock(); const TInterfaceBlock *interfaceBlock = node->getLeft()->getType().getInterfaceBlock();
ASSERT(structure || interfaceBlock); ASSERT(structure || interfaceBlock);
...@@ -557,10 +557,10 @@ void TOutputTraverser::visitConstantUnion(TIntermConstantUnion *node) ...@@ -557,10 +557,10 @@ void TOutputTraverser::visitConstantUnion(TIntermConstantUnion *node)
for (size_t i = 0; i < size; i++) for (size_t i = 0; i < size; i++)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, mDepth);
switch (node->getUnionArrayPointer()[i].getType()) switch (node->getConstantValue()[i].getType())
{ {
case EbtBool: case EbtBool:
if (node->getUnionArrayPointer()[i].getBConst()) if (node->getConstantValue()[i].getBConst())
mOut << "true"; mOut << "true";
else else
mOut << "false"; mOut << "false";
...@@ -571,20 +571,20 @@ void TOutputTraverser::visitConstantUnion(TIntermConstantUnion *node) ...@@ -571,20 +571,20 @@ void TOutputTraverser::visitConstantUnion(TIntermConstantUnion *node)
mOut << "\n"; mOut << "\n";
break; break;
case EbtFloat: case EbtFloat:
mOut << node->getUnionArrayPointer()[i].getFConst(); mOut << node->getConstantValue()[i].getFConst();
mOut << " (const float)\n"; mOut << " (const float)\n";
break; break;
case EbtInt: case EbtInt:
mOut << node->getUnionArrayPointer()[i].getIConst(); mOut << node->getConstantValue()[i].getIConst();
mOut << " (const int)\n"; mOut << " (const int)\n";
break; break;
case EbtUInt: case EbtUInt:
mOut << node->getUnionArrayPointer()[i].getUConst(); mOut << node->getConstantValue()[i].getUConst();
mOut << " (const uint)\n"; mOut << " (const uint)\n";
break; break;
case EbtYuvCscStandardEXT: case EbtYuvCscStandardEXT:
mOut << getYuvCscStandardEXTString( mOut << getYuvCscStandardEXTString(
node->getUnionArrayPointer()[i].getYuvCscStandardEXTConst()); node->getConstantValue()[i].getYuvCscStandardEXTConst());
mOut << " (const yuvCscStandardEXT)\n"; mOut << " (const yuvCscStandardEXT)\n";
break; break;
default: default:
......
...@@ -1870,7 +1870,7 @@ TIntermTyped *TParseContext::parseVariableIdentifier(const TSourceLoc &location, ...@@ -1870,7 +1870,7 @@ TIntermTyped *TParseContext::parseVariableIdentifier(const TSourceLoc &location,
const TType &variableType = variable->getType(); const TType &variableType = variable->getType();
TIntermTyped *node = nullptr; TIntermTyped *node = nullptr;
if (variable->getConstPointer()) if (variable->getConstPointer() && variableType.canReplaceWithConstantUnion())
{ {
const TConstantUnion *constArray = variable->getConstPointer(); const TConstantUnion *constArray = variable->getConstPointer();
node = new TIntermConstantUnion(constArray, variableType); node = new TIntermConstantUnion(constArray, variableType);
...@@ -1989,24 +1989,13 @@ bool TParseContext::executeInitializer(const TSourceLoc &line, ...@@ -1989,24 +1989,13 @@ bool TParseContext::executeInitializer(const TSourceLoc &line,
return false; return false;
} }
// Save the constant folded value to the variable if possible. For example array // Save the constant folded value to the variable if possible.
// initializers are not folded, since that way copying the array literal to multiple places const TConstantUnion *constArray = initializer->getConstantValue();
// in the shader is avoided.
// TODO(oetuaho@nvidia.com): Consider constant folding array initialization in cases where
// it would be beneficial.
if (initializer->getAsConstantUnion())
{
variable->shareConstPointer(initializer->getAsConstantUnion()->getUnionArrayPointer());
ASSERT(*initNode == nullptr);
return true;
}
else if (initializer->getAsSymbolNode())
{
const TVariable &var = initializer->getAsSymbolNode()->variable();
const TConstantUnion *constArray = var.getConstPointer();
if (constArray) if (constArray)
{ {
variable->shareConstPointer(constArray); variable->shareConstPointer(constArray);
if (initializer->getType().canReplaceWithConstantUnion())
{
ASSERT(*initNode == nullptr); ASSERT(*initNode == nullptr);
return true; return true;
} }
...@@ -4071,7 +4060,7 @@ TIntermTyped *TParseContext::addIndexExpression(TIntermTyped *baseExpression, ...@@ -4071,7 +4060,7 @@ TIntermTyped *TParseContext::addIndexExpression(TIntermTyped *baseExpression,
TIntermBinary *node = TIntermBinary *node =
new TIntermBinary(EOpIndexDirect, baseExpression, indexExpression); new TIntermBinary(EOpIndexDirect, baseExpression, indexExpression);
node->setLine(location); node->setLine(location);
return node->fold(mDiagnostics); return expressionOrFoldedResult(node);
} }
} }
...@@ -4154,7 +4143,7 @@ TIntermTyped *TParseContext::addFieldSelectionExpression(TIntermTyped *baseExpre ...@@ -4154,7 +4143,7 @@ TIntermTyped *TParseContext::addFieldSelectionExpression(TIntermTyped *baseExpre
TIntermBinary *node = TIntermBinary *node =
new TIntermBinary(EOpIndexDirectStruct, baseExpression, index); new TIntermBinary(EOpIndexDirectStruct, baseExpression, index);
node->setLine(dotLocation); node->setLine(dotLocation);
return node->fold(mDiagnostics); return expressionOrFoldedResult(node);
} }
else else
{ {
...@@ -5357,9 +5346,7 @@ TIntermTyped *TParseContext::addBinaryMathInternal(TOperator op, ...@@ -5357,9 +5346,7 @@ TIntermTyped *TParseContext::addBinaryMathInternal(TOperator op,
TIntermBinary *node = new TIntermBinary(op, left, right); TIntermBinary *node = new TIntermBinary(op, left, right);
node->setLine(loc); node->setLine(loc);
return expressionOrFoldedResult(node);
// See if we can fold constants.
return node->fold(mDiagnostics);
} }
TIntermTyped *TParseContext::addBinaryMath(TOperator op, TIntermTyped *TParseContext::addBinaryMath(TOperator op,
...@@ -5633,7 +5620,7 @@ void TParseContext::checkTextureOffsetConst(TIntermAggregate *functionCall) ...@@ -5633,7 +5620,7 @@ void TParseContext::checkTextureOffsetConst(TIntermAggregate *functionCall)
{ {
ASSERT(offsetConstantUnion->getBasicType() == EbtInt); ASSERT(offsetConstantUnion->getBasicType() == EbtInt);
size_t size = offsetConstantUnion->getType().getObjectSize(); size_t size = offsetConstantUnion->getType().getObjectSize();
const TConstantUnion *values = offsetConstantUnion->getUnionArrayPointer(); const TConstantUnion *values = offsetConstantUnion->getConstantValue();
int minOffsetValue = useTextureGatherOffsetConstraints ? mMinProgramTextureGatherOffset int minOffsetValue = useTextureGatherOffsetConstraints ? mMinProgramTextureGatherOffset
: mMinProgramTexelOffset; : mMinProgramTexelOffset;
int maxOffsetValue = useTextureGatherOffsetConstraints ? mMaxProgramTextureGatherOffset int maxOffsetValue = useTextureGatherOffsetConstraints ? mMaxProgramTextureGatherOffset
......
...@@ -236,7 +236,7 @@ bool RemoveUnreferencedVariablesTraverser::visitDeclaration(Visit visit, TInterm ...@@ -236,7 +236,7 @@ bool RemoveUnreferencedVariablesTraverser::visitDeclaration(Visit visit, TInterm
// We can only remove variables that are not a part of the shader interface. // We can only remove variables that are not a part of the shader interface.
TQualifier qualifier = declarator->getQualifier(); TQualifier qualifier = declarator->getQualifier();
if (qualifier != EvqTemporary && qualifier != EvqGlobal) if (qualifier != EvqTemporary && qualifier != EvqGlobal && qualifier != EvqConst)
{ {
return true; return true;
} }
......
...@@ -55,7 +55,7 @@ bool SeparateArrayInitTraverser::visitDeclaration(Visit, TIntermDeclaration *nod ...@@ -55,7 +55,7 @@ bool SeparateArrayInitTraverser::visitDeclaration(Visit, TIntermDeclaration *nod
if (initNode != nullptr && initNode->getOp() == EOpInitialize) if (initNode != nullptr && initNode->getOp() == EOpInitialize)
{ {
TIntermTyped *initializer = initNode->getRight(); TIntermTyped *initializer = initNode->getRight();
if (initializer->isArray() && !sh::OutputHLSL::canWriteAsHLSLLiteral(initializer)) if (initializer->isArray() && !initializer->hasConstantValue())
{ {
// We rely on that array declarations have been isolated to single declarations. // We rely on that array declarations have been isolated to single declarations.
ASSERT(sequence->size() == 1); ASSERT(sequence->size() == 1);
......
...@@ -452,6 +452,27 @@ bool TType::isStructureContainingSamplers() const ...@@ -452,6 +452,27 @@ bool TType::isStructureContainingSamplers() const
return mStructure ? mStructure->containsSamplers() : false; return mStructure ? mStructure->containsSamplers() : false;
} }
bool TType::canReplaceWithConstantUnion() const
{
if (isArray())
{
return false;
}
if (!mStructure)
{
return true;
}
if (isStructureContainingArrays())
{
return false;
}
if (getObjectSize() > 16)
{
return false;
}
return true;
}
// //
// Recursively generate mangled names. // Recursively generate mangled names.
// //
......
...@@ -305,6 +305,12 @@ class TType ...@@ -305,6 +305,12 @@ class TType
bool isStructSpecifier() const { return mIsStructSpecifier; } bool isStructSpecifier() const { return mIsStructSpecifier; }
// Return true if variables of this type should be replaced with an inline constant value if
// such is available. False will be returned in cases where output doesn't support
// TIntermConstantUnion nodes of the type, or if the type contains a lot of fields and creating
// several copies of it in the output code is undesirable for performance.
bool canReplaceWithConstantUnion() const;
void createSamplerSymbols(const TString &namePrefix, void createSamplerSymbols(const TString &namePrefix,
const TString &apiNamePrefix, const TString &apiNamePrefix,
TVector<const TVariable *> *outputSymbols, TVector<const TVariable *> *outputSymbols,
......
...@@ -1451,3 +1451,159 @@ TEST_F(ConstantFoldingTest, FoldTernaryInsideExpression) ...@@ -1451,3 +1451,159 @@ TEST_F(ConstantFoldingTest, FoldTernaryInsideExpression)
ASSERT_TRUE(constantFoundInAST(3)); ASSERT_TRUE(constantFoundInAST(3));
ASSERT_FALSE(symbolFoundInMain("u")); ASSERT_FALSE(symbolFoundInMain("u"));
} }
// Fold indexing into an array constructor.
TEST_F(ConstantFoldingExpressionTest, FoldArrayConstructorIndexing)
{
const std::string &floatString = "(float[3](-1.0, 1.0, 2.0))[2]";
evaluateFloat(floatString);
ASSERT_FALSE(constantFoundInAST(-1.0f));
ASSERT_FALSE(constantFoundInAST(1.0f));
ASSERT_TRUE(constantFoundInAST(2.0f));
}
// Fold indexing into an array of arrays constructor.
TEST_F(ConstantFoldingExpressionTest, FoldArrayOfArraysConstructorIndexing)
{
const std::string &floatString = "(float[2][2](float[2](-1.0, 1.0), float[2](2.0, 3.0)))[1][0]";
evaluateFloat(floatString);
ASSERT_FALSE(constantFoundInAST(-1.0f));
ASSERT_FALSE(constantFoundInAST(1.0f));
ASSERT_FALSE(constantFoundInAST(3.0f));
ASSERT_TRUE(constantFoundInAST(2.0f));
}
// Fold indexing into a named constant array.
TEST_F(ConstantFoldingTest, FoldNamedArrayIndexing)
{
const std::string &shaderString =
R"(#version 300 es
precision highp float;
const float[3] arr = float[3](-1.0, 1.0, 2.0);
out float my_FragColor;
void main()
{
my_FragColor = arr[1];
})";
compileAssumeSuccess(shaderString);
ASSERT_FALSE(constantFoundInAST(-1.0f));
ASSERT_FALSE(constantFoundInAST(2.0f));
ASSERT_TRUE(constantFoundInAST(1.0f));
// The variable should be pruned out since after folding the indexing, there are no more
// references to it.
ASSERT_FALSE(symbolFoundInAST("arr"));
}
// Fold indexing into a named constant array of arrays.
TEST_F(ConstantFoldingTest, FoldNamedArrayOfArraysIndexing)
{
const std::string &shaderString =
R"(#version 310 es
precision highp float;
const float[2][2] arr = float[2][2](float[2](-1.0, 1.0), float[2](2.0, 3.0));
out float my_FragColor;
void main()
{
my_FragColor = arr[0][1];
})";
compileAssumeSuccess(shaderString);
ASSERT_FALSE(constantFoundInAST(-1.0f));
ASSERT_FALSE(constantFoundInAST(2.0f));
ASSERT_FALSE(constantFoundInAST(3.0f));
ASSERT_TRUE(constantFoundInAST(1.0f));
// The variable should be pruned out since after folding the indexing, there are no more
// references to it.
ASSERT_FALSE(symbolFoundInAST("arr"));
}
// Fold indexing into an array constructor where some of the arguments are constant and others are
// non-constant but without side effects.
TEST_F(ConstantFoldingTest, FoldArrayConstructorIndexingWithMixedArguments)
{
const std::string &shaderString =
R"(#version 300 es
precision highp float;
uniform float u;
out float my_FragColor;
void main()
{
my_FragColor = float[2](u, 1.0)[1];
})";
compileAssumeSuccess(shaderString);
ASSERT_TRUE(constantFoundInAST(1.0f));
ASSERT_FALSE(constantFoundInAST(1));
ASSERT_FALSE(symbolFoundInMain("u"));
}
// Indexing into an array constructor where some of the arguments have side effects can't be folded.
TEST_F(ConstantFoldingTest, CantFoldArrayConstructorIndexingWithSideEffects)
{
const std::string &shaderString =
R"(#version 300 es
precision highp float;
out float my_FragColor;
void main()
{
float sideEffectTarget = 0.0;
float f = float[3](sideEffectTarget = 1.0, 1.0, 2.0)[1];
my_FragColor = f + sideEffectTarget;
})";
compileAssumeSuccess(shaderString);
// All of the array constructor arguments should be present in the final AST.
ASSERT_TRUE(constantFoundInAST(1.0f));
ASSERT_TRUE(constantFoundInAST(2.0f));
}
// Fold comparing two array constructors.
TEST_F(ConstantFoldingTest, FoldArrayConstructorEquality)
{
const std::string &shaderString =
R"(#version 300 es
precision highp float;
out float my_FragColor;
void main()
{
const bool b = (float[3](2.0, 1.0, -1.0) == float[3](2.0, 1.0, -1.0));
my_FragColor = b ? 3.0 : 4.0;
})";
compileAssumeSuccess(shaderString);
ASSERT_TRUE(constantFoundInAST(3.0f));
ASSERT_FALSE(constantFoundInAST(4.0f));
}
// Fold comparing two named constant arrays.
TEST_F(ConstantFoldingExpressionTest, FoldNamedArrayEquality)
{
const std::string &shaderString =
R"(#version 300 es
precision highp float;
const float[3] arrA = float[3](-1.0, 1.0, 2.0);
const float[3] arrB = float[3](-1.0, 1.0, 2.0);
out float my_FragColor;
void main()
{
const bool b = (arrA == arrB);
my_FragColor = b ? 3.0 : 4.0;
})";
compileAssumeSuccess(shaderString);
ASSERT_TRUE(constantFoundInAST(3.0f));
ASSERT_FALSE(constantFoundInAST(4.0f));
}
// Fold comparing two array of arrays constructors.
TEST_F(ConstantFoldingTest, FoldArrayOfArraysConstructorEquality)
{
const std::string &shaderString =
R"(#version 310 es
precision highp float;
out float my_FragColor;
void main()
{
const bool b = (float[2][2](float[2](-1.0, 1.0), float[2](2.0, 3.0)) ==
float[2][2](float[2](-1.0, 1.0), float[2](2.0, 1000.0)));
my_FragColor = b ? 4.0 : 5.0;
})";
compileAssumeSuccess(shaderString);
ASSERT_TRUE(constantFoundInAST(5.0f));
ASSERT_FALSE(constantFoundInAST(4.0f));
}
...@@ -5680,3 +5680,24 @@ TEST_F(FragmentShaderValidationTest, CommaReturnsNonConstant) ...@@ -5680,3 +5680,24 @@ TEST_F(FragmentShaderValidationTest, CommaReturnsNonConstant)
FAIL() << "Shader compilation succeeded, expecting failure:\n" << mInfoLog; FAIL() << "Shader compilation succeeded, expecting failure:\n" << mInfoLog;
} }
} }
// Test that the result of indexing into an array constructor with some non-constant arguments is
// not a constant expression.
TEST_F(FragmentShaderValidationTest,
IndexingIntoArrayConstructorWithNonConstantArgumentsIsNotConstantExpression)
{
const std::string &shaderString =
R"(#version 310 es
precision highp float;
uniform float u;
out float my_FragColor;
void main()
{
const float f = float[2](u, 1.0)[1];
my_FragColor = f;
})";
if (compile(shaderString))
{
FAIL() << "Shader compilation succeeded, expecting failure:\n" << mInfoLog;
}
}
...@@ -17,39 +17,51 @@ using namespace sh; ...@@ -17,39 +17,51 @@ using namespace sh;
void ConstantFoldingExpressionTest::evaluateFloat(const std::string &floatExpression) void ConstantFoldingExpressionTest::evaluateFloat(const std::string &floatExpression)
{ {
// We first assign the expression into a const variable so we can also verify that it gets
// qualified as a constant expression. We then assign that constant expression into my_FragColor
// to make sure that the value is not pruned.
std::stringstream shaderStream; std::stringstream shaderStream;
shaderStream << "#version 310 es\n" shaderStream << "#version 310 es\n"
"precision mediump float;\n" "precision mediump float;\n"
"out float my_FragColor;\n" "out float my_FragColor;\n"
"void main()\n" "void main()\n"
"{\n" "{\n"
<< " my_FragColor = " << floatExpression << ";\n" << " const float f = " << floatExpression << ";\n"
<< "}\n"; << " my_FragColor = f;\n"
"}\n";
compileAssumeSuccess(shaderStream.str()); compileAssumeSuccess(shaderStream.str());
} }
void ConstantFoldingExpressionTest::evaluateInt(const std::string &intExpression) void ConstantFoldingExpressionTest::evaluateInt(const std::string &intExpression)
{ {
// We first assign the expression into a const variable so we can also verify that it gets
// qualified as a constant expression. We then assign that constant expression into my_FragColor
// to make sure that the value is not pruned.
std::stringstream shaderStream; std::stringstream shaderStream;
shaderStream << "#version 310 es\n" shaderStream << "#version 310 es\n"
"precision mediump int;\n" "precision mediump int;\n"
"out int my_FragColor;\n" "out int my_FragColor;\n"
"void main()\n" "void main()\n"
"{\n" "{\n"
<< " my_FragColor = " << intExpression << ";\n" << " const int i = " << intExpression << ";\n"
<< "}\n"; << " my_FragColor = i;\n"
"}\n";
compileAssumeSuccess(shaderStream.str()); compileAssumeSuccess(shaderStream.str());
} }
void ConstantFoldingExpressionTest::evaluateUint(const std::string &uintExpression) void ConstantFoldingExpressionTest::evaluateUint(const std::string &uintExpression)
{ {
// We first assign the expression into a const variable so we can also verify that it gets
// qualified as a constant expression. We then assign that constant expression into my_FragColor
// to make sure that the value is not pruned.
std::stringstream shaderStream; std::stringstream shaderStream;
shaderStream << "#version 310 es\n" shaderStream << "#version 310 es\n"
"precision mediump int;\n" "precision mediump int;\n"
"out uint my_FragColor;\n" "out uint my_FragColor;\n"
"void main()\n" "void main()\n"
"{\n" "{\n"
<< " my_FragColor = " << uintExpression << ";\n" << " const uint u = " << uintExpression << ";\n"
<< "}\n"; << " my_FragColor = u;\n"
"}\n";
compileAssumeSuccess(shaderStream.str()); compileAssumeSuccess(shaderStream.str());
} }
...@@ -56,7 +56,7 @@ class ConstantFinder : public TIntermTraverser ...@@ -56,7 +56,7 @@ class ConstantFinder : public TIntermTraverser
bool found = true; bool found = true;
for (size_t i = 0; i < mConstantVector.size(); i++) for (size_t i = 0; i < mConstantVector.size(); i++)
{ {
if (!isEqual(node->getUnionArrayPointer()[i], mConstantVector[i])) if (!isEqual(node->getConstantValue()[i], mConstantVector[i]))
{ {
found = false; found = false;
break; break;
...@@ -172,6 +172,11 @@ class ConstantFoldingTest : public ShaderCompileTreeTest ...@@ -172,6 +172,11 @@ class ConstantFoldingTest : public ShaderCompileTreeTest
return finder.found(); return finder.found();
} }
bool symbolFoundInAST(const char *symbolName)
{
return FindSymbolNode(mASTRoot, TString(symbolName)) != nullptr;
}
bool symbolFoundInMain(const char *symbolName) bool symbolFoundInMain(const char *symbolName)
{ {
return FindSymbolNode(FindMain(mASTRoot), TString(symbolName)) != nullptr; return FindSymbolNode(FindMain(mASTRoot), TString(symbolName)) != nullptr;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment