Commit b11e2483 by Jamie Madill

translator: Fix validation sometimes modifying builtins.

When validating some shaders with out-of-bounds array indexes, we would write the sanitized index into the global symbol table. We would then overwrite a wrong value for the builtin. This fixes the WebGL test extensions/webgl-draw-buffers-max-draw-buffers. Also mark const on as many uses ConstantUnion as we can. BUG=angleproject:993 Change-Id: I110efaf1b7b0158b08b704277e3bc2472437902c Reviewed-on: https://chromium-review.googlesource.com/268962Tested-by: 's avatarJamie Madill <jmadill@chromium.org> Reviewed-by: 's avatarKenneth Russell <kbr@chromium.org> Reviewed-by: 's avatarZhenyao Mo <zmo@chromium.org>
parent 6ba6eadc
...@@ -64,12 +64,12 @@ bool ValidateMultiplication(TOperator op, const TType &left, const TType &right) ...@@ -64,12 +64,12 @@ bool ValidateMultiplication(TOperator op, const TType &left, const TType &right)
} }
bool CompareStructure(const TType& leftNodeType, bool CompareStructure(const TType& leftNodeType,
TConstantUnion *rightUnionArray, const TConstantUnion *rightUnionArray,
TConstantUnion *leftUnionArray); const TConstantUnion *leftUnionArray);
bool CompareStruct(const TType &leftNodeType, bool CompareStruct(const TType &leftNodeType,
TConstantUnion *rightUnionArray, const TConstantUnion *rightUnionArray,
TConstantUnion *leftUnionArray) const TConstantUnion *leftUnionArray)
{ {
const TFieldList &fields = leftNodeType.getStruct()->fields(); const TFieldList &fields = leftNodeType.getStruct()->fields();
...@@ -102,8 +102,8 @@ bool CompareStruct(const TType &leftNodeType, ...@@ -102,8 +102,8 @@ bool CompareStruct(const TType &leftNodeType,
} }
bool CompareStructure(const TType &leftNodeType, bool CompareStructure(const TType &leftNodeType,
TConstantUnion *rightUnionArray, const TConstantUnion *rightUnionArray,
TConstantUnion *leftUnionArray) const TConstantUnion *leftUnionArray)
{ {
if (leftNodeType.isArray()) if (leftNodeType.isArray())
{ {
......
...@@ -270,7 +270,8 @@ class TIntermConstantUnion : public TIntermTyped ...@@ -270,7 +270,8 @@ class TIntermConstantUnion : public TIntermTyped
virtual bool hasSideEffects() const { return false; } virtual bool hasSideEffects() const { return false; }
TConstantUnion *getUnionArrayPointer() const { return mUnionArrayPointer; } const TConstantUnion *getUnionArrayPointer() const { return mUnionArrayPointer; }
TConstantUnion *getUnionArrayPointer() { return mUnionArrayPointer; }
int getIConst(size_t index) const int getIConst(size_t index) const
{ {
...@@ -289,6 +290,12 @@ class TIntermConstantUnion : public TIntermTyped ...@@ -289,6 +290,12 @@ class TIntermConstantUnion : public TIntermTyped
return mUnionArrayPointer ? mUnionArrayPointer[index].getBConst() : false; return mUnionArrayPointer ? mUnionArrayPointer[index].getBConst() : false;
} }
void replaceConstantUnion(TConstantUnion *safeConstantUnion)
{
// Previous union pointer freed on pool deallocation.
mUnionArrayPointer = safeConstantUnion;
}
virtual TIntermConstantUnion *getAsConstantUnion() { return this; } virtual TIntermConstantUnion *getAsConstantUnion() { return this; }
virtual void traverse(TIntermTraverser *); virtual void traverse(TIntermTraverser *);
virtual bool replaceChildNode(TIntermNode *, TIntermNode *) { return false; } virtual bool replaceChildNode(TIntermNode *, TIntermNode *) { return false; }
......
...@@ -361,9 +361,9 @@ TIntermCase *TIntermediate::addCase( ...@@ -361,9 +361,9 @@ TIntermCase *TIntermediate::addCase(
// //
TIntermConstantUnion *TIntermediate::addConstantUnion( TIntermConstantUnion *TIntermediate::addConstantUnion(
TConstantUnion *unionArrayPointer, const TType &t, const TSourceLoc &line) TConstantUnion *constantUnion, const TType &type, const TSourceLoc &line)
{ {
TIntermConstantUnion *node = new TIntermConstantUnion(unionArrayPointer, t); TIntermConstantUnion *node = new TIntermConstantUnion(constantUnion, type);
node->setLine(line); node->setLine(line);
return node; return node;
......
...@@ -49,7 +49,8 @@ class TIntermediate ...@@ -49,7 +49,8 @@ class TIntermediate
TIntermTyped *condition, const TSourceLoc &line); TIntermTyped *condition, const TSourceLoc &line);
TIntermTyped *addComma( TIntermTyped *addComma(
TIntermTyped *left, TIntermTyped *right, const TSourceLoc &); TIntermTyped *left, TIntermTyped *right, const TSourceLoc &);
TIntermConstantUnion *addConstantUnion(TConstantUnion *, const TType &, const TSourceLoc &); TIntermConstantUnion *addConstantUnion(
TConstantUnion *constantUnion, const TType &type, const TSourceLoc &line);
// TODO(zmo): Get rid of default value. // TODO(zmo): Get rid of default value.
bool parseConstTree(const TSourceLoc &, TIntermNode *, TConstantUnion *, bool parseConstTree(const TSourceLoc &, TIntermNode *, TConstantUnion *,
TOperator, TType, bool singleConstantParam = false); TOperator, TType, bool singleConstantParam = false);
......
...@@ -1814,7 +1814,7 @@ TIntermTyped* TParseContext::addConstVectorNode(TVectorFields& fields, TIntermTy ...@@ -1814,7 +1814,7 @@ TIntermTyped* TParseContext::addConstVectorNode(TVectorFields& fields, TIntermTy
TIntermTyped* typedNode; TIntermTyped* typedNode;
TIntermConstantUnion* tempConstantNode = node->getAsConstantUnion(); TIntermConstantUnion* tempConstantNode = node->getAsConstantUnion();
TConstantUnion *unionArray; const TConstantUnion *unionArray;
if (tempConstantNode) { if (tempConstantNode) {
unionArray = tempConstantNode->getUnionArrayPointer(); unionArray = tempConstantNode->getUnionArrayPointer();
...@@ -1868,7 +1868,7 @@ TIntermTyped* TParseContext::addConstMatrixNode(int index, TIntermTyped* node, c ...@@ -1868,7 +1868,7 @@ TIntermTyped* TParseContext::addConstMatrixNode(int index, TIntermTyped* node, c
} }
if (tempConstantNode) { if (tempConstantNode) {
TConstantUnion* unionArray = tempConstantNode->getUnionArrayPointer(); TConstantUnion *unionArray = tempConstantNode->getUnionArrayPointer();
int size = tempConstantNode->getType().getCols(); int size = tempConstantNode->getType().getCols();
typedNode = intermediate.addConstantUnion(&unionArray[size*index], tempConstantNode->getType(), line); typedNode = intermediate.addConstantUnion(&unionArray[size*index], tempConstantNode->getType(), line);
} else { } else {
...@@ -2202,6 +2202,8 @@ TIntermTyped* TParseContext::addIndexExpression(TIntermTyped *baseExpression, co ...@@ -2202,6 +2202,8 @@ TIntermTyped* TParseContext::addIndexExpression(TIntermTyped *baseExpression, co
} }
else else
{ {
int safeIndex = -1;
if (baseExpression->isArray()) if (baseExpression->isArray())
{ {
if (index >= baseExpression->getType().getArraySize()) if (index >= baseExpression->getType().getArraySize())
...@@ -2211,13 +2213,13 @@ TIntermTyped* TParseContext::addIndexExpression(TIntermTyped *baseExpression, co ...@@ -2211,13 +2213,13 @@ TIntermTyped* TParseContext::addIndexExpression(TIntermTyped *baseExpression, co
std::string extraInfo = extraInfoStream.str(); std::string extraInfo = extraInfoStream.str();
error(location, "", "[", extraInfo.c_str()); error(location, "", "[", extraInfo.c_str());
recover(); recover();
index = baseExpression->getType().getArraySize() - 1; safeIndex = baseExpression->getType().getArraySize() - 1;
} }
else if (baseExpression->getQualifier() == EvqFragData && index > 0 && !isExtensionEnabled("GL_EXT_draw_buffers")) else if (baseExpression->getQualifier() == EvqFragData && index > 0 && !isExtensionEnabled("GL_EXT_draw_buffers"))
{ {
error(location, "", "[", "array indexes for gl_FragData must be zero when GL_EXT_draw_buffers is disabled"); error(location, "", "[", "array indexes for gl_FragData must be zero when GL_EXT_draw_buffers is disabled");
recover(); recover();
index = 0; safeIndex = 0;
} }
} }
else if ((baseExpression->isVector() || baseExpression->isMatrix()) && baseExpression->getType().getNominalSize() <= index) else if ((baseExpression->isVector() || baseExpression->isMatrix()) && baseExpression->getType().getNominalSize() <= index)
...@@ -2227,10 +2229,18 @@ TIntermTyped* TParseContext::addIndexExpression(TIntermTyped *baseExpression, co ...@@ -2227,10 +2229,18 @@ TIntermTyped* TParseContext::addIndexExpression(TIntermTyped *baseExpression, co
std::string extraInfo = extraInfoStream.str(); std::string extraInfo = extraInfoStream.str();
error(location, "", "[", extraInfo.c_str()); error(location, "", "[", extraInfo.c_str());
recover(); recover();
index = baseExpression->getType().getNominalSize() - 1; safeIndex = baseExpression->getType().getNominalSize() - 1;
}
// Don't modify the data of the previous constant union, because it can point
// to builtins, like gl_MaxDrawBuffers. Instead use a new sanitized object.
if (safeIndex != -1)
{
TConstantUnion *safeConstantUnion = new TConstantUnion();
safeConstantUnion->setIConst(safeIndex);
indexConstantUnion->replaceConstantUnion(safeConstantUnion);
} }
indexConstantUnion->getUnionArrayPointer()->setIConst(index);
indexedExpression = intermediate.addIndex(EOpIndexDirect, baseExpression, indexExpression, location); indexedExpression = intermediate.addIndex(EOpIndexDirect, baseExpression, indexExpression, location);
} }
} }
......
...@@ -177,7 +177,7 @@ void TConstTraverser::visitConstantUnion(TIntermConstantUnion *node) ...@@ -177,7 +177,7 @@ void TConstTraverser::visitConstantUnion(TIntermConstantUnion *node)
if (!mSingleConstantParam) if (!mSingleConstantParam)
{ {
size_t objectSize = node->getType().getObjectSize(); size_t objectSize = node->getType().getObjectSize();
TConstantUnion *rightUnionArray = node->getUnionArrayPointer(); const TConstantUnion *rightUnionArray = node->getUnionArrayPointer();
for (size_t i=0; i < objectSize; i++) for (size_t i=0; i < objectSize; i++)
{ {
if (mIndex >= instanceSize) if (mIndex >= instanceSize)
...@@ -189,7 +189,7 @@ void TConstTraverser::visitConstantUnion(TIntermConstantUnion *node) ...@@ -189,7 +189,7 @@ void TConstTraverser::visitConstantUnion(TIntermConstantUnion *node)
else else
{ {
size_t totalSize = mIndex + mSize; size_t totalSize = mIndex + mSize;
TConstantUnion *rightUnionArray = node->getUnionArrayPointer(); const TConstantUnion *rightUnionArray = node->getUnionArrayPointer();
if (!mIsDiagonalMatrixInit) if (!mIsDiagonalMatrixInit)
{ {
int count = 0; int count = 0;
......
...@@ -81,7 +81,7 @@ class ShaderExtensionTest : public testing::Test ...@@ -81,7 +81,7 @@ class ShaderExtensionTest : public testing::Test
{ {
bool success = ShCompile(mCompiler, shaderStrings, stringCount, 0); bool success = ShCompile(mCompiler, shaderStrings, stringCount, 0);
const std::string& compileLog = ShGetInfoLog(mCompiler); const std::string& compileLog = ShGetInfoLog(mCompiler);
EXPECT_EQ(success, expectation) << compileLog; EXPECT_EQ(expectation, success) << compileLog;
} }
protected: protected:
...@@ -195,3 +195,23 @@ TEST_F(ShaderExtensionTest, TextureLODExtResetsInternalStates) ...@@ -195,3 +195,23 @@ TEST_F(ShaderExtensionTest, TextureLODExtResetsInternalStates)
TestShaderExtension(&TextureLODShdr[1], 1, false); TestShaderExtension(&TextureLODShdr[1], 1, false);
TestShaderExtension(TextureLODShdr, 2, true); TestShaderExtension(TextureLODShdr, 2, true);
} }
// Test a bug where we could modify the value of a builtin variable.
TEST_F(ShaderExtensionTest, BuiltinRewritingBug)
{
mResources.MaxDrawBuffers = 4;
mResources.EXT_draw_buffers = 1;
InitializeCompiler();
const std::string &shaderString =
"#extension GL_EXT_draw_buffers : require\n"
"precision mediump float;\n"
"void main() {\n"
" gl_FragData[gl_MaxDrawBuffers] = vec4(0.0);\n"
"}";
const char *shaderStrings[] = { shaderString.c_str() };
TestShaderExtension(shaderStrings, 1, false);
TestShaderExtension(shaderStrings, 1, false);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment