Commit 90892fbd by Olli Etuaho Committed by Commit Bot

Refine swizzle/indexing constant folding code

Fix constant folding of subscripting non-square matrices. Previously constant folding would offset the pointer into the matrix in multiples of the number of columns, when it should offset the pointer in multiples of the number of rows. Also change the MalformedShaderTest so that it only succeeds if vector swizzle is being checked correctly. Previously compilation would fail in the test either way because the shader code contained a call to an undefined function. Also refactor indexing checks and constant folding so that constant folding is done entirely separately from out-of-range checks. Bogus comments are removed from the constant folding functions. BUG=angleproject:1444 TEST=angle_unittests, angle_end2end_tests Change-Id: I7073b38f759e9b3635ee05947df4f6d8e23a39d5 Reviewed-on: https://chromium-review.googlesource.com/360112Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org> Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
parent 313d9447
...@@ -2459,94 +2459,51 @@ TIntermTyped *TParseContext::addConstructor(TIntermNode *arguments, ...@@ -2459,94 +2459,51 @@ TIntermTyped *TParseContext::addConstructor(TIntermNode *arguments,
return constructor; return constructor;
} }
// // This function returns vector field(s) being accessed from a constant vector.
// This function returns the tree representation for the vector field(s) being accessed from contant TIntermConstantUnion *TParseContext::foldVectorSwizzle(TVectorFields &fields,
// vector. TIntermConstantUnion *baseNode,
// If only one component of vector is accessed (v.x or v[0] where v is a contant vector), then a const TSourceLoc &location)
// contant node is returned, else an aggregate node is returned (for v.xy). The input to this
// function could either be the symbol node or it could be the intermediate tree representation of
// accessing fields in a constant structure or column of a constant matrix.
//
TIntermTyped *TParseContext::addConstVectorNode(TVectorFields &fields,
TIntermConstantUnion *node,
const TSourceLoc &line,
bool outOfRangeIndexIsError)
{ {
const TConstantUnion *unionArray = node->getUnionArrayPointer(); const TConstantUnion *unionArray = baseNode->getUnionArrayPointer();
ASSERT(unionArray); ASSERT(unionArray);
TConstantUnion *constArray = new TConstantUnion[fields.num]; TConstantUnion *constArray = new TConstantUnion[fields.num];
const auto &type = node->getType(); const auto &type = baseNode->getType();
for (int i = 0; i < fields.num; i++) for (int i = 0; i < fields.num; i++)
{ {
if (fields.offsets[i] >= type.getNominalSize()) // Out-of-range indices should already be checked.
{ ASSERT(fields.offsets[i] < type.getNominalSize());
std::stringstream extraInfoStream;
extraInfoStream << "vector field selection out of range '" << fields.offsets[i] << "'";
std::string extraInfo = extraInfoStream.str();
outOfRangeError(outOfRangeIndexIsError, line, "", "[", extraInfo.c_str());
fields.offsets[i] = type.getNominalSize() - 1;
}
constArray[i] = unionArray[fields.offsets[i]]; constArray[i] = unionArray[fields.offsets[i]];
} }
return intermediate.addConstantUnion(constArray, type, line); return intermediate.addConstantUnion(constArray, type, location);
} }
// // This function returns the column vector being accessed from a constant matrix.
// This function returns the column being accessed from a constant matrix. The values are retrieved TIntermConstantUnion *TParseContext::foldMatrixSubscript(int index,
// from the symbol table and parse-tree is built for a vector (each column of a matrix is a vector). TIntermConstantUnion *baseNode,
// The input to the function could either be a symbol node (m[0] where m is a constant matrix)that const TSourceLoc &location)
// represents a constant matrix or it could be the tree representation of the constant matrix
// (s.m1[0] where s is a constant structure)
//
TIntermTyped *TParseContext::addConstMatrixNode(int index,
TIntermConstantUnion *node,
const TSourceLoc &line,
bool outOfRangeIndexIsError)
{ {
if (index >= node->getType().getCols()) ASSERT(index < baseNode->getType().getCols());
{
std::stringstream extraInfoStream;
extraInfoStream << "matrix field selection out of range '" << index << "'";
std::string extraInfo = extraInfoStream.str();
outOfRangeError(outOfRangeIndexIsError, line, "", "[", extraInfo.c_str());
index = node->getType().getCols() - 1;
}
const TConstantUnion *unionArray = node->getUnionArrayPointer(); const TConstantUnion *unionArray = baseNode->getUnionArrayPointer();
int size = node->getType().getCols(); int size = baseNode->getType().getRows();
return intermediate.addConstantUnion(&unionArray[size * index], node->getType(), line); return intermediate.addConstantUnion(&unionArray[size * index], baseNode->getType(), location);
} }
// // This function returns an element of an array accessed from a constant array.
// This function returns an element of an array accessed from a constant array. The values are TIntermConstantUnion *TParseContext::foldArraySubscript(int index,
// retrieved from the symbol table and parse-tree is built for the type of the element. The input TIntermConstantUnion *baseNode,
// to the function could either be a symbol node (a[0] where a is a constant array)that represents a const TSourceLoc &location)
// constant array or it could be the tree representation of the constant array (s.a1[0] where s is a
// constant structure)
//
TIntermTyped *TParseContext::addConstArrayNode(int index,
TIntermConstantUnion *node,
const TSourceLoc &line,
bool outOfRangeIndexIsError)
{ {
TType arrayElementType = node->getType(); ASSERT(index < baseNode->getArraySize());
arrayElementType.clearArrayness();
if (index >= node->getType().getArraySize()) TType arrayElementType = baseNode->getType();
{ arrayElementType.clearArrayness();
std::stringstream extraInfoStream;
extraInfoStream << "array field selection out of range '" << index << "'";
std::string extraInfo = extraInfoStream.str();
outOfRangeError(outOfRangeIndexIsError, line, "", "[", extraInfo.c_str());
index = node->getType().getArraySize() - 1;
}
size_t arrayElementSize = arrayElementType.getObjectSize(); size_t arrayElementSize = arrayElementType.getObjectSize();
const TConstantUnion *unionArray = node->getUnionArrayPointer(); const TConstantUnion *unionArray = baseNode->getUnionArrayPointer();
return intermediate.addConstantUnion(&unionArray[arrayElementSize * index], node->getType(), return intermediate.addConstantUnion(&unionArray[arrayElementSize * index], baseNode->getType(),
line); location);
} }
// //
...@@ -2875,38 +2832,47 @@ TIntermTyped *TParseContext::addIndexExpression(TIntermTyped *baseExpression, ...@@ -2875,38 +2832,47 @@ TIntermTyped *TParseContext::addIndexExpression(TIntermTyped *baseExpression,
// correct range. // correct range.
bool outOfRangeIndexIsError = indexExpression->getQualifier() == EvqConst; bool outOfRangeIndexIsError = indexExpression->getQualifier() == EvqConst;
int index = indexConstantUnion->getIConst(0); int index = indexConstantUnion->getIConst(0);
if (index < 0) if (!baseExpression->isArray())
{ {
std::stringstream infoStream; // Array checks are done later because a different error message might be generated
infoStream << index; // based on the index in some cases.
std::string info = infoStream.str(); if (baseExpression->isVector())
outOfRangeError(outOfRangeIndexIsError, location, "negative index", info.c_str()); {
index = 0; index = checkIndexOutOfRange(outOfRangeIndexIsError, location, index,
baseExpression->getType().getNominalSize(),
"vector field selection out of range", "[]");
}
else if (baseExpression->isMatrix())
{
index = checkIndexOutOfRange(outOfRangeIndexIsError, location, index,
baseExpression->getType().getCols(),
"matrix field selection out of range", "[]");
}
} }
TIntermConstantUnion *baseConstantUnion = baseExpression->getAsConstantUnion(); TIntermConstantUnion *baseConstantUnion = baseExpression->getAsConstantUnion();
if (baseConstantUnion) if (baseConstantUnion)
{ {
if (baseExpression->isArray()) if (baseExpression->isArray())
{ {
// constant folding for array indexing index = checkIndexOutOfRange(outOfRangeIndexIsError, location, index,
indexedExpression = baseExpression->getArraySize(),
addConstArrayNode(index, baseConstantUnion, location, outOfRangeIndexIsError); "array index out of range", "[]");
// Constant folding for array indexing.
indexedExpression = foldArraySubscript(index, baseConstantUnion, location);
} }
else if (baseExpression->isVector()) else if (baseExpression->isVector())
{ {
// constant folding for vector indexing // Constant folding for vector indexing - reusing vector swizzle folding.
TVectorFields fields; TVectorFields fields;
fields.num = 1; fields.num = 1;
fields.offsets[0] = fields.offsets[0] = index;
index; // need to do it this way because v.xy sends fields integer array indexedExpression = foldVectorSwizzle(fields, baseConstantUnion, location);
indexedExpression =
addConstVectorNode(fields, baseConstantUnion, location, outOfRangeIndexIsError);
} }
else if (baseExpression->isMatrix()) else if (baseExpression->isMatrix())
{ {
// constant folding for matrix indexing // Constant folding for matrix indexing.
indexedExpression = indexedExpression = foldMatrixSubscript(index, baseConstantUnion, location);
addConstMatrixNode(index, baseConstantUnion, location, outOfRangeIndexIsError);
} }
} }
else else
...@@ -2937,24 +2903,13 @@ TIntermTyped *TParseContext::addIndexExpression(TIntermTyped *baseExpression, ...@@ -2937,24 +2903,13 @@ TIntermTyped *TParseContext::addIndexExpression(TIntermTyped *baseExpression,
} }
} }
// Only do generic out-of-range check if similar error hasn't already been reported. // Only do generic out-of-range check if similar error hasn't already been reported.
if (safeIndex < 0 && index >= baseExpression->getType().getArraySize()) if (safeIndex < 0)
{ {
std::stringstream extraInfoStream; safeIndex = checkIndexOutOfRange(outOfRangeIndexIsError, location, index,
extraInfoStream << "array index out of range '" << index << "'"; baseExpression->getArraySize(),
std::string extraInfo = extraInfoStream.str(); "array index out of range", "[]");
outOfRangeError(outOfRangeIndexIsError, location, "", "[", extraInfo.c_str());
safeIndex = baseExpression->getType().getArraySize() - 1;
} }
} }
else if ((baseExpression->isVector() || baseExpression->isMatrix()) &&
baseExpression->getType().getNominalSize() <= index)
{
std::stringstream extraInfoStream;
extraInfoStream << "field selection out of range '" << index << "'";
std::string extraInfo = extraInfoStream.str();
outOfRangeError(outOfRangeIndexIsError, location, "", "[", extraInfo.c_str());
safeIndex = baseExpression->getType().getNominalSize() - 1;
}
// Data of constant unions can't be changed, because it may be shared with other // Data of constant unions can't be changed, because it may be shared with other
// constant unions or even builtins, like gl_MaxDrawBuffers. Instead use a new // constant unions or even builtins, like gl_MaxDrawBuffers. Instead use a new
...@@ -3018,6 +2973,31 @@ TIntermTyped *TParseContext::addIndexExpression(TIntermTyped *baseExpression, ...@@ -3018,6 +2973,31 @@ TIntermTyped *TParseContext::addIndexExpression(TIntermTyped *baseExpression,
return indexedExpression; return indexedExpression;
} }
int TParseContext::checkIndexOutOfRange(bool outOfRangeIndexIsError,
const TSourceLoc &location,
int index,
int arraySize,
const char *reason,
const char *token)
{
if (index >= arraySize || index < 0)
{
std::stringstream extraInfoStream;
extraInfoStream << "'" << index << "'";
std::string extraInfo = extraInfoStream.str();
outOfRangeError(outOfRangeIndexIsError, location, reason, token, extraInfo.c_str());
if (index < 0)
{
return 0;
}
else
{
return arraySize - 1;
}
}
return index;
}
TIntermTyped *TParseContext::addFieldSelectionExpression(TIntermTyped *baseExpression, TIntermTyped *TParseContext::addFieldSelectionExpression(TIntermTyped *baseExpression,
const TSourceLoc &dotLocation, const TSourceLoc &dotLocation,
const TString &fieldString, const TString &fieldString,
...@@ -3045,8 +3025,8 @@ TIntermTyped *TParseContext::addFieldSelectionExpression(TIntermTyped *baseExpre ...@@ -3045,8 +3025,8 @@ TIntermTyped *TParseContext::addFieldSelectionExpression(TIntermTyped *baseExpre
if (baseExpression->getAsConstantUnion()) if (baseExpression->getAsConstantUnion())
{ {
// constant folding for vector fields // constant folding for vector fields
indexedExpression = addConstVectorNode(fields, baseExpression->getAsConstantUnion(), indexedExpression =
fieldLocation, true); foldVectorSwizzle(fields, baseExpression->getAsConstantUnion(), fieldLocation);
} }
else else
{ {
......
...@@ -253,18 +253,7 @@ class TParseContext : angle::NonCopyable ...@@ -253,18 +253,7 @@ class TParseContext : angle::NonCopyable
TOperator op, TOperator op,
TFunction *fnCall, TFunction *fnCall,
const TSourceLoc &line); const TSourceLoc &line);
TIntermTyped *addConstVectorNode(TVectorFields &fields,
TIntermConstantUnion *node,
const TSourceLoc &line,
bool outOfRangeIndexIsError);
TIntermTyped *addConstMatrixNode(int index,
TIntermConstantUnion *node,
const TSourceLoc &line,
bool outOfRangeIndexIsError);
TIntermTyped *addConstArrayNode(int index,
TIntermConstantUnion *node,
const TSourceLoc &line,
bool outOfRangeIndexIsError);
TIntermTyped *addConstStruct( TIntermTyped *addConstStruct(
const TString &identifier, TIntermTyped *node, const TSourceLoc& line); const TString &identifier, TIntermTyped *node, const TSourceLoc& line);
TIntermTyped *addIndexExpression(TIntermTyped *baseExpression, TIntermTyped *addIndexExpression(TIntermTyped *baseExpression,
...@@ -342,6 +331,26 @@ class TParseContext : angle::NonCopyable ...@@ -342,6 +331,26 @@ class TParseContext : angle::NonCopyable
TSymbolTable &symbolTable; // symbol table that goes with the language currently being parsed TSymbolTable &symbolTable; // symbol table that goes with the language currently being parsed
private: private:
// Returns a clamped index.
int checkIndexOutOfRange(bool outOfRangeIndexIsError,
const TSourceLoc &location,
int index,
int arraySize,
const char *reason,
const char *token);
// Constant folding for element access. Note that the returned node does not have the correct
// type - it is expected to be fixed later.
TIntermConstantUnion *foldVectorSwizzle(TVectorFields &fields,
TIntermConstantUnion *baseNode,
const TSourceLoc &location);
TIntermConstantUnion *foldMatrixSubscript(int index,
TIntermConstantUnion *baseNode,
const TSourceLoc &location);
TIntermConstantUnion *foldArraySubscript(int index,
TIntermConstantUnion *baseNode,
const TSourceLoc &location);
bool declareVariable(const TSourceLoc &line, const TString &identifier, const TType &type, TVariable **variable); bool declareVariable(const TSourceLoc &line, const TString &identifier, const TType &type, TVariable **variable);
bool nonInitErrorCheck(const TSourceLoc &line, const TString &identifier, TPublicType *type); bool nonInitErrorCheck(const TSourceLoc &line, const TString &identifier, TPublicType *type);
......
...@@ -744,3 +744,20 @@ TEST_F(ConstantFoldingTest, FoldNestedIdenticalStructEqualityComparison) ...@@ -744,3 +744,20 @@ TEST_F(ConstantFoldingTest, FoldNestedIdenticalStructEqualityComparison)
compile(shaderString); compile(shaderString);
ASSERT_TRUE(constantFoundInAST(1.0f)); ASSERT_TRUE(constantFoundInAST(1.0f));
} }
// Test that right elements are chosen from non-square matrix
TEST_F(ConstantFoldingTest, FoldNonSquareMatrixIndexing)
{
const std::string &shaderString =
"#version 300 es\n"
"precision mediump float;\n"
"out vec4 my_FragColor;\n"
"void main()\n"
"{\n"
" my_FragColor = mat3x4(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)[1];\n"
"}\n";
compile(shaderString);
float outputElements[] = {4.0f, 5.0f, 6.0f, 7.0f};
std::vector<float> result(outputElements, outputElements + 4);
ASSERT_TRUE(constantVectorFoundInAST(result));
}
...@@ -1611,12 +1611,12 @@ TEST_F(MalformedShaderTest, CompoundMultiplyMatrixValidNonSquareDimensions) ...@@ -1611,12 +1611,12 @@ TEST_F(MalformedShaderTest, CompoundMultiplyMatrixValidNonSquareDimensions)
} }
} }
// Covers a bug where we would set the incorrect result size on an out-of-bounds vector sizzle. // Covers a bug where we would set the incorrect result size on an out-of-bounds vector swizzle.
TEST_F(MalformedShaderTest, OutOfBoundsVectorSwizzle) TEST_F(MalformedShaderTest, OutOfBoundsVectorSwizzle)
{ {
const std::string &shaderString = const std::string &shaderString =
"void main() {\n" "void main() {\n"
" vec2(0).qq * a(b);\n" " vec2(0).qq;\n"
"}\n"; "}\n";
if (compile(shaderString)) if (compile(shaderString))
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment