Commit ea22b7a5 by Olli Etuaho Committed by Commit Bot

Constant fold array indexing and comparison

A virtual function to get the constant value of an AST node is added to TIntermTyped. This way a constant value can be retrieved conveniently from multiple different types of nodes. TIntermSymbol nodes pointing to a const variable can return the value associated with the variable, constructor nodes can build a constant value from their arguments, and indexing nodes can index into a constant array. This enables constant folding operations on constant arrays, while making sure that large amounts of data are not duplicated in the output shader. When folding an operation makes sense, the values of the arguments can be retrieved by using the new TIntermTyped::getConstantValue(). When folding an operation would result in duplicating data, the AST can just be left to be written out as is. For example, if the code contains a constant array of arrays, indexing into individual elements of the inner arrays can be folded, but indexing the top level array is left in place and not replaced with duplicated array literals. Constant folding is supported for indexing and comparisons of arrays. In case constant arrays are only referenced through foldable operations, the variable declarations will be pruned from the AST by the RemoveUnreferencedVariables step. BUG=angleproject:2298 TEST=angle_unittests Change-Id: I5b3be237b7e9fdba56aa9bf0a41b691f4d8f01eb Reviewed-on: https://chromium-review.googlesource.com/850973Reviewed-by: 's avatarGeoff Lang <geofflang@chromium.org> Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
parent ed1390f8
......@@ -45,14 +45,12 @@ void GetDeferredInitializers(TIntermDeclaration *declaration,
ASSERT(symbolNode);
TIntermTyped *expression = init->getRight();
if ((expression->getQualifier() != EvqConst ||
(expression->getAsConstantUnion() == nullptr &&
!expression->isConstructorWithOnlyConstantUnionParameters())))
if (expression->getQualifier() != EvqConst || !expression->hasConstantValue())
{
// For variables which are not constant, defer their real initialization until
// after we initialize uniforms.
// Deferral is done also in any cases where the variable has not been constant
// folded, since otherwise there's a chance that HLSL output will generate extra
// Deferral is done also in any cases where the variable can not be converted to a
// constant union, since otherwise there's a chance that HLSL output will generate extra
// statements from the initializer expression.
// Change const global to a regular global if its initialization is deferred.
......
......@@ -81,7 +81,7 @@ bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
}
ASSERT(constantExponent->getBasicType() == EbtFloat);
float exponentValue = constantExponent->getUnionArrayPointer()->getFConst();
float exponentValue = constantExponent->getConstantValue()->getFConst();
// Test 2: exponentValue is in the problematic range.
if (exponentValue < -5.0f || exponentValue > 9.0f)
......
......@@ -131,6 +131,12 @@ class TIntermTyped : public TIntermNode
virtual TIntermTyped *fold(TDiagnostics *diagnostics) { return this; }
// getConstantValue() returns the constant value that this node represents, if any. It
// should only be used after nodes have been replaced with their folded versions returned
// from fold(). hasConstantValue() returns true if getConstantValue() will return a value.
virtual bool hasConstantValue() const;
virtual const TConstantUnion *getConstantValue() const;
// True if executing the expression represented by this node affects state, like values of
// variables. False if the executing the expression only computes its return value without
// affecting state. May return true conservatively.
......@@ -161,8 +167,6 @@ class TIntermTyped : public TIntermNode
unsigned int getOutermostArraySize() const { return mType.getOutermostArraySize(); }
bool isConstructorWithOnlyConstantUnionParameters();
protected:
TType mType;
......@@ -241,6 +245,9 @@ class TIntermSymbol : public TIntermTyped
TIntermTyped *deepCopy() const override { return new TIntermSymbol(*this); }
bool hasConstantValue() const override;
const TConstantUnion *getConstantValue() const override;
bool hasSideEffects() const override { return false; }
const TSymbolUniqueId &uniqueId() const;
......@@ -302,9 +309,10 @@ class TIntermConstantUnion : public TIntermTyped
TIntermTyped *deepCopy() const override { return new TIntermConstantUnion(*this); }
bool hasSideEffects() const override { return false; }
bool hasConstantValue() const override;
const TConstantUnion *getConstantValue() const override;
const TConstantUnion *getUnionArrayPointer() const { return mUnionArrayPointer; }
bool hasSideEffects() const override { return false; }
int getIConst(size_t index) const
{
......@@ -334,15 +342,20 @@ class TIntermConstantUnion : public TIntermTyped
void traverse(TIntermTraverser *it) override;
bool replaceChildNode(TIntermNode *, TIntermNode *) override { return false; }
TConstantUnion *foldBinary(TOperator op,
TIntermConstantUnion *rightNode,
TDiagnostics *diagnostics,
const TSourceLoc &line);
const TConstantUnion *foldIndexing(int index);
TConstantUnion *foldUnaryNonComponentWise(TOperator op);
TConstantUnion *foldUnaryComponentWise(TOperator op, TDiagnostics *diagnostics);
static TConstantUnion *FoldAggregateConstructor(TIntermAggregate *aggregate);
static const TConstantUnion *FoldBinary(TOperator op,
const TConstantUnion *leftArray,
const TType &leftType,
const TConstantUnion *rightArray,
const TType &rightType,
TDiagnostics *diagnostics,
const TSourceLoc &line);
static const TConstantUnion *FoldIndexing(const TType &type,
const TConstantUnion *constArray,
int index);
static TConstantUnion *FoldAggregateBuiltIn(TIntermAggregate *aggregate,
TDiagnostics *diagnostics);
......@@ -430,6 +443,9 @@ class TIntermBinary : public TIntermOperator
TIntermTyped *deepCopy() const override { return new TIntermBinary(*this); }
bool hasConstantValue() const override;
const TConstantUnion *getConstantValue() const override;
static TOperator GetMulOpBasedOnOperands(const TType &left, const TType &right);
static TOperator GetMulAssignOpBasedOnOperands(const TType &left, const TType &right);
static TQualifier GetCommaQualifier(int shaderVersion,
......@@ -553,6 +569,9 @@ class TIntermAggregate : public TIntermOperator, public TIntermAggregateBase
TIntermAggregate *shallowCopy() const;
bool hasConstantValue() const override;
const TConstantUnion *getConstantValue() const override;
TIntermAggregate *getAsAggregate() override { return this; }
void traverse(TIntermTraverser *it) override;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
......
......@@ -441,7 +441,7 @@ void TOutputGLSLBase::visitSymbol(TIntermSymbol *node)
void TOutputGLSLBase::visitConstantUnion(TIntermConstantUnion *node)
{
writeConstantUnion(node->getType(), node->getUnionArrayPointer());
writeConstantUnion(node->getType(), node->getConstantValue());
}
bool TOutputGLSLBase::visitSwizzle(Visit visit, TIntermSwizzle *node)
......
......@@ -419,7 +419,8 @@ void OutputHLSL::header(TInfoSinkBase &out,
// Program linking depends on this exact format
varyings += "static " + InterpolationString(type.getQualifier()) + " " + TypeString(type) +
" " + Decorate(name) + ArrayString(type) + " = " + initializer(type) + ";\n";
" " + Decorate(name) + ArrayString(type) + " = " + zeroInitializer(type) +
";\n";
}
for (const auto &attribute : mReferencedAttributes)
......@@ -428,7 +429,7 @@ void OutputHLSL::header(TInfoSinkBase &out,
const TString &name = attribute.second->name();
attributes += "static " + TypeString(type) + " " + Decorate(name) + ArrayString(type) +
" = " + initializer(type) + ";\n";
" = " + zeroInitializer(type) + ";\n";
}
out << mStructureHLSL->structsHeader();
......@@ -501,7 +502,8 @@ void OutputHLSL::header(TInfoSinkBase &out,
const TType &variableType = outputVariable.second->getType();
out << "static " + TypeString(variableType) + " out_" + variableName +
ArrayString(variableType) + " = " + initializer(variableType) + ";\n";
ArrayString(variableType) + " = " + zeroInitializer(variableType) +
";\n";
}
}
else
......@@ -1156,23 +1158,22 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
{
TIntermSymbol *symbolNode = node->getLeft()->getAsSymbolNode();
ASSERT(symbolNode);
TIntermTyped *expression = node->getRight();
TIntermTyped *initializer = node->getRight();
// Global initializers must be constant at this point.
ASSERT(symbolNode->getQualifier() != EvqGlobal ||
canWriteAsHLSLLiteral(expression));
ASSERT(symbolNode->getQualifier() != EvqGlobal || initializer->hasConstantValue());
// GLSL allows to write things like "float x = x;" where a new variable x is defined
// and the value of an existing variable x is assigned. HLSL uses C semantics (the
// new variable is created before the assignment is evaluated), so we need to
// convert
// this to "float t = x, x = t;".
if (writeSameSymbolInitializer(out, symbolNode, expression))
if (writeSameSymbolInitializer(out, symbolNode, initializer))
{
// Skip initializing the rest of the expression
return false;
}
else if (writeConstantInitialization(out, symbolNode, expression))
else if (writeConstantInitialization(out, symbolNode, initializer))
{
return false;
}
......@@ -1838,7 +1839,7 @@ bool OutputHLSL::visitDeclaration(Visit visit, TIntermDeclaration *node)
{
symbol->traverse(this);
out << ArrayString(symbol->getType());
out << " = " + initializer(symbol->getType());
out << " = " + zeroInitializer(symbol->getType());
}
else
{
......@@ -2240,7 +2241,7 @@ bool OutputHLSL::visitCase(Visit visit, TIntermCase *node)
void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
{
TInfoSinkBase &out = getInfoSink();
writeConstantUnion(out, node->getType(), node->getUnionArrayPointer());
writeConstantUnion(out, node->getType(), node->getConstantValue());
}
bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node)
......@@ -2712,7 +2713,7 @@ TString OutputHLSL::argumentString(const TIntermSymbol *symbol)
return argString.str();
}
TString OutputHLSL::initializer(const TType &type)
TString OutputHLSL::zeroInitializer(const TType &type)
{
TString string;
......@@ -2763,6 +2764,8 @@ const TConstantUnion *OutputHLSL::writeConstantUnion(TInfoSinkBase &out,
const TType &type,
const TConstantUnion *const constUnion)
{
ASSERT(!type.isArray());
const TConstantUnion *constUnionIterated = constUnion;
const TStructure *structure = type.getStruct();
......@@ -2841,50 +2844,17 @@ bool OutputHLSL::writeSameSymbolInitializer(TInfoSinkBase &out,
return false;
}
bool OutputHLSL::canWriteAsHLSLLiteral(TIntermTyped *expression)
{
// We support writing constant unions and constructors that only take constant unions as
// parameters as HLSL literals.
return !expression->getType().isArrayOfArrays() &&
(expression->getAsConstantUnion() ||
expression->isConstructorWithOnlyConstantUnionParameters());
}
bool OutputHLSL::writeConstantInitialization(TInfoSinkBase &out,
TIntermSymbol *symbolNode,
TIntermTyped *initializer)
{
if (canWriteAsHLSLLiteral(initializer))
if (initializer->hasConstantValue())
{
symbolNode->traverse(this);
ASSERT(!symbolNode->getType().isArrayOfArrays());
if (symbolNode->getType().isArray())
{
out << "[" << symbolNode->getType().getOutermostArraySize() << "]";
}
out << ArrayString(symbolNode->getType());
out << " = {";
if (initializer->getAsConstantUnion())
{
TIntermConstantUnion *nodeConst = initializer->getAsConstantUnion();
const TConstantUnion *constUnion = nodeConst->getUnionArrayPointer();
writeConstantUnionArray(out, constUnion, nodeConst->getType().getObjectSize());
}
else
{
TIntermAggregate *constructor = initializer->getAsAggregate();
ASSERT(constructor != nullptr);
for (TIntermNode *&node : *constructor->getSequence())
{
TIntermConstantUnion *nodeConst = node->getAsConstantUnion();
ASSERT(nodeConst);
const TConstantUnion *constUnion = nodeConst->getUnionArrayPointer();
writeConstantUnionArray(out, constUnion, nodeConst->getType().getObjectSize());
if (node != constructor->getSequence()->back())
{
out << ", ";
}
}
}
writeConstantUnionArray(out, initializer->getConstantValue(),
initializer->getType().getObjectSize());
out << "}";
return true;
}
......
......@@ -62,7 +62,7 @@ class OutputHLSL : public TIntermTraverser
const std::map<std::string, unsigned int> &getUniformBlockRegisterMap() const;
const std::map<std::string, unsigned int> &getUniformRegisterMap() const;
static TString initializer(const TType &type);
static TString zeroInitializer(const TType &type);
TInfoSinkBase &getInfoSink()
{
......@@ -70,8 +70,6 @@ class OutputHLSL : public TIntermTraverser
return *mInfoSinkStack.top();
}
static bool canWriteAsHLSLLiteral(TIntermTyped *expression);
protected:
void header(TInfoSinkBase &out,
const std::vector<MappedStruct> &std140Structs,
......
......@@ -273,7 +273,7 @@ bool TOutputTraverser::visitBinary(Visit visit, TIntermBinary *node)
OutputTreeText(mOut, intermConstantUnion, mDepth + 1);
// The following code finds the field name from the constant union
const TConstantUnion *constantUnion = intermConstantUnion->getUnionArrayPointer();
const TConstantUnion *constantUnion = intermConstantUnion->getConstantValue();
const TStructure *structure = node->getLeft()->getType().getStruct();
const TInterfaceBlock *interfaceBlock = node->getLeft()->getType().getInterfaceBlock();
ASSERT(structure || interfaceBlock);
......@@ -557,10 +557,10 @@ void TOutputTraverser::visitConstantUnion(TIntermConstantUnion *node)
for (size_t i = 0; i < size; i++)
{
OutputTreeText(mOut, node, mDepth);
switch (node->getUnionArrayPointer()[i].getType())
switch (node->getConstantValue()[i].getType())
{
case EbtBool:
if (node->getUnionArrayPointer()[i].getBConst())
if (node->getConstantValue()[i].getBConst())
mOut << "true";
else
mOut << "false";
......@@ -571,20 +571,20 @@ void TOutputTraverser::visitConstantUnion(TIntermConstantUnion *node)
mOut << "\n";
break;
case EbtFloat:
mOut << node->getUnionArrayPointer()[i].getFConst();
mOut << node->getConstantValue()[i].getFConst();
mOut << " (const float)\n";
break;
case EbtInt:
mOut << node->getUnionArrayPointer()[i].getIConst();
mOut << node->getConstantValue()[i].getIConst();
mOut << " (const int)\n";
break;
case EbtUInt:
mOut << node->getUnionArrayPointer()[i].getUConst();
mOut << node->getConstantValue()[i].getUConst();
mOut << " (const uint)\n";
break;
case EbtYuvCscStandardEXT:
mOut << getYuvCscStandardEXTString(
node->getUnionArrayPointer()[i].getYuvCscStandardEXTConst());
node->getConstantValue()[i].getYuvCscStandardEXTConst());
mOut << " (const yuvCscStandardEXT)\n";
break;
default:
......
......@@ -1870,7 +1870,7 @@ TIntermTyped *TParseContext::parseVariableIdentifier(const TSourceLoc &location,
const TType &variableType = variable->getType();
TIntermTyped *node = nullptr;
if (variable->getConstPointer())
if (variable->getConstPointer() && variableType.canReplaceWithConstantUnion())
{
const TConstantUnion *constArray = variable->getConstPointer();
node = new TIntermConstantUnion(constArray, variableType);
......@@ -1989,24 +1989,13 @@ bool TParseContext::executeInitializer(const TSourceLoc &line,
return false;
}
// Save the constant folded value to the variable if possible. For example array
// initializers are not folded, since that way copying the array literal to multiple places
// in the shader is avoided.
// TODO(oetuaho@nvidia.com): Consider constant folding array initialization in cases where
// it would be beneficial.
if (initializer->getAsConstantUnion())
// Save the constant folded value to the variable if possible.
const TConstantUnion *constArray = initializer->getConstantValue();
if (constArray)
{
variable->shareConstPointer(initializer->getAsConstantUnion()->getUnionArrayPointer());
ASSERT(*initNode == nullptr);
return true;
}
else if (initializer->getAsSymbolNode())
{
const TVariable &var = initializer->getAsSymbolNode()->variable();
const TConstantUnion *constArray = var.getConstPointer();
if (constArray)
variable->shareConstPointer(constArray);
if (initializer->getType().canReplaceWithConstantUnion())
{
variable->shareConstPointer(constArray);
ASSERT(*initNode == nullptr);
return true;
}
......@@ -4071,7 +4060,7 @@ TIntermTyped *TParseContext::addIndexExpression(TIntermTyped *baseExpression,
TIntermBinary *node =
new TIntermBinary(EOpIndexDirect, baseExpression, indexExpression);
node->setLine(location);
return node->fold(mDiagnostics);
return expressionOrFoldedResult(node);
}
}
......@@ -4154,7 +4143,7 @@ TIntermTyped *TParseContext::addFieldSelectionExpression(TIntermTyped *baseExpre
TIntermBinary *node =
new TIntermBinary(EOpIndexDirectStruct, baseExpression, index);
node->setLine(dotLocation);
return node->fold(mDiagnostics);
return expressionOrFoldedResult(node);
}
else
{
......@@ -5357,9 +5346,7 @@ TIntermTyped *TParseContext::addBinaryMathInternal(TOperator op,
TIntermBinary *node = new TIntermBinary(op, left, right);
node->setLine(loc);
// See if we can fold constants.
return node->fold(mDiagnostics);
return expressionOrFoldedResult(node);
}
TIntermTyped *TParseContext::addBinaryMath(TOperator op,
......@@ -5633,7 +5620,7 @@ void TParseContext::checkTextureOffsetConst(TIntermAggregate *functionCall)
{
ASSERT(offsetConstantUnion->getBasicType() == EbtInt);
size_t size = offsetConstantUnion->getType().getObjectSize();
const TConstantUnion *values = offsetConstantUnion->getUnionArrayPointer();
const TConstantUnion *values = offsetConstantUnion->getConstantValue();
int minOffsetValue = useTextureGatherOffsetConstraints ? mMinProgramTextureGatherOffset
: mMinProgramTexelOffset;
int maxOffsetValue = useTextureGatherOffsetConstraints ? mMaxProgramTextureGatherOffset
......
......@@ -236,7 +236,7 @@ bool RemoveUnreferencedVariablesTraverser::visitDeclaration(Visit visit, TInterm
// We can only remove variables that are not a part of the shader interface.
TQualifier qualifier = declarator->getQualifier();
if (qualifier != EvqTemporary && qualifier != EvqGlobal)
if (qualifier != EvqTemporary && qualifier != EvqGlobal && qualifier != EvqConst)
{
return true;
}
......
......@@ -55,7 +55,7 @@ bool SeparateArrayInitTraverser::visitDeclaration(Visit, TIntermDeclaration *nod
if (initNode != nullptr && initNode->getOp() == EOpInitialize)
{
TIntermTyped *initializer = initNode->getRight();
if (initializer->isArray() && !sh::OutputHLSL::canWriteAsHLSLLiteral(initializer))
if (initializer->isArray() && !initializer->hasConstantValue())
{
// We rely on that array declarations have been isolated to single declarations.
ASSERT(sequence->size() == 1);
......
......@@ -452,6 +452,27 @@ bool TType::isStructureContainingSamplers() const
return mStructure ? mStructure->containsSamplers() : false;
}
bool TType::canReplaceWithConstantUnion() const
{
if (isArray())
{
return false;
}
if (!mStructure)
{
return true;
}
if (isStructureContainingArrays())
{
return false;
}
if (getObjectSize() > 16)
{
return false;
}
return true;
}
//
// Recursively generate mangled names.
//
......
......@@ -305,6 +305,12 @@ class TType
bool isStructSpecifier() const { return mIsStructSpecifier; }
// Return true if variables of this type should be replaced with an inline constant value if
// such is available. False will be returned in cases where output doesn't support
// TIntermConstantUnion nodes of the type, or if the type contains a lot of fields and creating
// several copies of it in the output code is undesirable for performance.
bool canReplaceWithConstantUnion() const;
void createSamplerSymbols(const TString &namePrefix,
const TString &apiNamePrefix,
TVector<const TVariable *> *outputSymbols,
......
......@@ -1451,3 +1451,159 @@ TEST_F(ConstantFoldingTest, FoldTernaryInsideExpression)
ASSERT_TRUE(constantFoundInAST(3));
ASSERT_FALSE(symbolFoundInMain("u"));
}
// Fold indexing into an array constructor.
TEST_F(ConstantFoldingExpressionTest, FoldArrayConstructorIndexing)
{
const std::string &floatString = "(float[3](-1.0, 1.0, 2.0))[2]";
evaluateFloat(floatString);
ASSERT_FALSE(constantFoundInAST(-1.0f));
ASSERT_FALSE(constantFoundInAST(1.0f));
ASSERT_TRUE(constantFoundInAST(2.0f));
}
// Fold indexing into an array of arrays constructor.
TEST_F(ConstantFoldingExpressionTest, FoldArrayOfArraysConstructorIndexing)
{
const std::string &floatString = "(float[2][2](float[2](-1.0, 1.0), float[2](2.0, 3.0)))[1][0]";
evaluateFloat(floatString);
ASSERT_FALSE(constantFoundInAST(-1.0f));
ASSERT_FALSE(constantFoundInAST(1.0f));
ASSERT_FALSE(constantFoundInAST(3.0f));
ASSERT_TRUE(constantFoundInAST(2.0f));
}
// Fold indexing into a named constant array.
TEST_F(ConstantFoldingTest, FoldNamedArrayIndexing)
{
const std::string &shaderString =
R"(#version 300 es
precision highp float;
const float[3] arr = float[3](-1.0, 1.0, 2.0);
out float my_FragColor;
void main()
{
my_FragColor = arr[1];
})";
compileAssumeSuccess(shaderString);
ASSERT_FALSE(constantFoundInAST(-1.0f));
ASSERT_FALSE(constantFoundInAST(2.0f));
ASSERT_TRUE(constantFoundInAST(1.0f));
// The variable should be pruned out since after folding the indexing, there are no more
// references to it.
ASSERT_FALSE(symbolFoundInAST("arr"));
}
// Fold indexing into a named constant array of arrays.
TEST_F(ConstantFoldingTest, FoldNamedArrayOfArraysIndexing)
{
const std::string &shaderString =
R"(#version 310 es
precision highp float;
const float[2][2] arr = float[2][2](float[2](-1.0, 1.0), float[2](2.0, 3.0));
out float my_FragColor;
void main()
{
my_FragColor = arr[0][1];
})";
compileAssumeSuccess(shaderString);
ASSERT_FALSE(constantFoundInAST(-1.0f));
ASSERT_FALSE(constantFoundInAST(2.0f));
ASSERT_FALSE(constantFoundInAST(3.0f));
ASSERT_TRUE(constantFoundInAST(1.0f));
// The variable should be pruned out since after folding the indexing, there are no more
// references to it.
ASSERT_FALSE(symbolFoundInAST("arr"));
}
// Fold indexing into an array constructor where some of the arguments are constant and others are
// non-constant but without side effects.
TEST_F(ConstantFoldingTest, FoldArrayConstructorIndexingWithMixedArguments)
{
const std::string &shaderString =
R"(#version 300 es
precision highp float;
uniform float u;
out float my_FragColor;
void main()
{
my_FragColor = float[2](u, 1.0)[1];
})";
compileAssumeSuccess(shaderString);
ASSERT_TRUE(constantFoundInAST(1.0f));
ASSERT_FALSE(constantFoundInAST(1));
ASSERT_FALSE(symbolFoundInMain("u"));
}
// Indexing into an array constructor where some of the arguments have side effects can't be folded.
TEST_F(ConstantFoldingTest, CantFoldArrayConstructorIndexingWithSideEffects)
{
const std::string &shaderString =
R"(#version 300 es
precision highp float;
out float my_FragColor;
void main()
{
float sideEffectTarget = 0.0;
float f = float[3](sideEffectTarget = 1.0, 1.0, 2.0)[1];
my_FragColor = f + sideEffectTarget;
})";
compileAssumeSuccess(shaderString);
// All of the array constructor arguments should be present in the final AST.
ASSERT_TRUE(constantFoundInAST(1.0f));
ASSERT_TRUE(constantFoundInAST(2.0f));
}
// Fold comparing two array constructors.
TEST_F(ConstantFoldingTest, FoldArrayConstructorEquality)
{
const std::string &shaderString =
R"(#version 300 es
precision highp float;
out float my_FragColor;
void main()
{
const bool b = (float[3](2.0, 1.0, -1.0) == float[3](2.0, 1.0, -1.0));
my_FragColor = b ? 3.0 : 4.0;
})";
compileAssumeSuccess(shaderString);
ASSERT_TRUE(constantFoundInAST(3.0f));
ASSERT_FALSE(constantFoundInAST(4.0f));
}
// Fold comparing two named constant arrays.
TEST_F(ConstantFoldingExpressionTest, FoldNamedArrayEquality)
{
const std::string &shaderString =
R"(#version 300 es
precision highp float;
const float[3] arrA = float[3](-1.0, 1.0, 2.0);
const float[3] arrB = float[3](-1.0, 1.0, 2.0);
out float my_FragColor;
void main()
{
const bool b = (arrA == arrB);
my_FragColor = b ? 3.0 : 4.0;
})";
compileAssumeSuccess(shaderString);
ASSERT_TRUE(constantFoundInAST(3.0f));
ASSERT_FALSE(constantFoundInAST(4.0f));
}
// Fold comparing two array of arrays constructors.
TEST_F(ConstantFoldingTest, FoldArrayOfArraysConstructorEquality)
{
const std::string &shaderString =
R"(#version 310 es
precision highp float;
out float my_FragColor;
void main()
{
const bool b = (float[2][2](float[2](-1.0, 1.0), float[2](2.0, 3.0)) ==
float[2][2](float[2](-1.0, 1.0), float[2](2.0, 1000.0)));
my_FragColor = b ? 4.0 : 5.0;
})";
compileAssumeSuccess(shaderString);
ASSERT_TRUE(constantFoundInAST(5.0f));
ASSERT_FALSE(constantFoundInAST(4.0f));
}
......@@ -5680,3 +5680,24 @@ TEST_F(FragmentShaderValidationTest, CommaReturnsNonConstant)
FAIL() << "Shader compilation succeeded, expecting failure:\n" << mInfoLog;
}
}
// Test that the result of indexing into an array constructor with some non-constant arguments is
// not a constant expression.
TEST_F(FragmentShaderValidationTest,
IndexingIntoArrayConstructorWithNonConstantArgumentsIsNotConstantExpression)
{
const std::string &shaderString =
R"(#version 310 es
precision highp float;
uniform float u;
out float my_FragColor;
void main()
{
const float f = float[2](u, 1.0)[1];
my_FragColor = f;
})";
if (compile(shaderString))
{
FAIL() << "Shader compilation succeeded, expecting failure:\n" << mInfoLog;
}
}
......@@ -17,39 +17,51 @@ using namespace sh;
void ConstantFoldingExpressionTest::evaluateFloat(const std::string &floatExpression)
{
// We first assign the expression into a const variable so we can also verify that it gets
// qualified as a constant expression. We then assign that constant expression into my_FragColor
// to make sure that the value is not pruned.
std::stringstream shaderStream;
shaderStream << "#version 310 es\n"
"precision mediump float;\n"
"out float my_FragColor;\n"
"void main()\n"
"{\n"
<< " my_FragColor = " << floatExpression << ";\n"
<< "}\n";
<< " const float f = " << floatExpression << ";\n"
<< " my_FragColor = f;\n"
"}\n";
compileAssumeSuccess(shaderStream.str());
}
void ConstantFoldingExpressionTest::evaluateInt(const std::string &intExpression)
{
// We first assign the expression into a const variable so we can also verify that it gets
// qualified as a constant expression. We then assign that constant expression into my_FragColor
// to make sure that the value is not pruned.
std::stringstream shaderStream;
shaderStream << "#version 310 es\n"
"precision mediump int;\n"
"out int my_FragColor;\n"
"void main()\n"
"{\n"
<< " my_FragColor = " << intExpression << ";\n"
<< "}\n";
<< " const int i = " << intExpression << ";\n"
<< " my_FragColor = i;\n"
"}\n";
compileAssumeSuccess(shaderStream.str());
}
void ConstantFoldingExpressionTest::evaluateUint(const std::string &uintExpression)
{
// We first assign the expression into a const variable so we can also verify that it gets
// qualified as a constant expression. We then assign that constant expression into my_FragColor
// to make sure that the value is not pruned.
std::stringstream shaderStream;
shaderStream << "#version 310 es\n"
"precision mediump int;\n"
"out uint my_FragColor;\n"
"void main()\n"
"{\n"
<< " my_FragColor = " << uintExpression << ";\n"
<< "}\n";
<< " const uint u = " << uintExpression << ";\n"
<< " my_FragColor = u;\n"
"}\n";
compileAssumeSuccess(shaderStream.str());
}
......@@ -56,7 +56,7 @@ class ConstantFinder : public TIntermTraverser
bool found = true;
for (size_t i = 0; i < mConstantVector.size(); i++)
{
if (!isEqual(node->getUnionArrayPointer()[i], mConstantVector[i]))
if (!isEqual(node->getConstantValue()[i], mConstantVector[i]))
{
found = false;
break;
......@@ -172,6 +172,11 @@ class ConstantFoldingTest : public ShaderCompileTreeTest
return finder.found();
}
bool symbolFoundInAST(const char *symbolName)
{
return FindSymbolNode(mASTRoot, TString(symbolName)) != nullptr;
}
bool symbolFoundInMain(const char *symbolName)
{
return FindSymbolNode(FindMain(mASTRoot), TString(symbolName)) != nullptr;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment