Commit ae04e1e4 by Olli Etuaho Committed by Commit Bot

Fix scalarizing vec and mat constructor args

Scalarizing vec and mat constructor args can generate new statements in the parent block of the constructor. To preserve the correct execution order of expressions, scalarized vector and matrix constructors need to be first moved out from inside loop conditions and sequence operators. This is done whenever the compiler flag to scalarize args is on. BUG=chromium:772653 TEST=angle_unittests Change-Id: Id40f8d848a9d087e186ef2e680c8e4cd440221d9 Reviewed-on: https://chromium-review.googlesource.com/790412 Commit-Queue: Olli Etuaho <oetuaho@nvidia.com> Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org>
parent f13cadd8
...@@ -516,20 +516,24 @@ bool TCompiler::checkAndSimplifyAST(TIntermBlock *root, ...@@ -516,20 +516,24 @@ bool TCompiler::checkAndSimplifyAST(TIntermBlock *root,
&symbolTable, shaderVersion); &symbolTable, shaderVersion);
} }
int simplifyScalarized = (compileOptions & SH_SCALARIZE_VEC_AND_MAT_CONSTRUCTOR_ARGS)
? IntermNodePatternMatcher::kScalarizedVecOrMatConstructor
: 0;
// Split multi declarations and remove calls to array length(). // Split multi declarations and remove calls to array length().
// Note that SimplifyLoopConditions needs to be run before any other AST transformations // Note that SimplifyLoopConditions needs to be run before any other AST transformations
// that may need to generate new statements from loop conditions or loop expressions. // that may need to generate new statements from loop conditions or loop expressions.
SimplifyLoopConditions( SimplifyLoopConditions(root,
root, IntermNodePatternMatcher::kMultiDeclaration |
IntermNodePatternMatcher::kMultiDeclaration | IntermNodePatternMatcher::kArrayLengthMethod, IntermNodePatternMatcher::kArrayLengthMethod | simplifyScalarized,
&getSymbolTable(), getShaderVersion()); &getSymbolTable(), getShaderVersion());
// Note that separate declarations need to be run before other AST transformations that // Note that separate declarations need to be run before other AST transformations that
// generate new statements from expressions. // generate new statements from expressions.
SeparateDeclarations(root); SeparateDeclarations(root);
SplitSequenceOperator(root, IntermNodePatternMatcher::kArrayLengthMethod, &getSymbolTable(), SplitSequenceOperator(root, IntermNodePatternMatcher::kArrayLengthMethod | simplifyScalarized,
getShaderVersion()); &getSymbolTable(), getShaderVersion());
RemoveArrayLengthMethod(root); RemoveArrayLengthMethod(root);
......
...@@ -15,6 +15,33 @@ ...@@ -15,6 +15,33 @@
namespace sh namespace sh
{ {
namespace
{
bool ContainsMatrixNode(const TIntermSequence &sequence)
{
for (size_t ii = 0; ii < sequence.size(); ++ii)
{
TIntermTyped *node = sequence[ii]->getAsTyped();
if (node && node->isMatrix())
return true;
}
return false;
}
bool ContainsVectorNode(const TIntermSequence &sequence)
{
for (size_t ii = 0; ii < sequence.size(); ++ii)
{
TIntermTyped *node = sequence[ii]->getAsTyped();
if (node && node->isVector())
return true;
}
return false;
}
} // anonymous namespace
IntermNodePatternMatcher::IntermNodePatternMatcher(const unsigned int mask) : mMask(mask) IntermNodePatternMatcher::IntermNodePatternMatcher(const unsigned int mask) : mMask(mask)
{ {
} }
...@@ -105,6 +132,20 @@ bool IntermNodePatternMatcher::match(TIntermAggregate *node, TIntermNode *parent ...@@ -105,6 +132,20 @@ bool IntermNodePatternMatcher::match(TIntermAggregate *node, TIntermNode *parent
} }
} }
} }
if ((mMask & kScalarizedVecOrMatConstructor) != 0)
{
if (node->getOp() == EOpConstruct)
{
if (node->getType().isVector() && ContainsMatrixNode(*(node->getSequence())))
{
return true;
}
else if (node->getType().isMatrix() && ContainsVectorNode(*(node->getSequence())))
{
return true;
}
}
}
return false; return false;
} }
......
...@@ -48,7 +48,11 @@ class IntermNodePatternMatcher ...@@ -48,7 +48,11 @@ class IntermNodePatternMatcher
kNamelessStructDeclaration = 0x0001 << 5, kNamelessStructDeclaration = 0x0001 << 5,
// Matches array length() method. // Matches array length() method.
kArrayLengthMethod = 0x0001 << 6 kArrayLengthMethod = 0x0001 << 6,
// Matches a vector or matrix constructor whose arguments are scalarized by the
// SH_SCALARIZE_VEC_OR_MAT_CONSTRUCTOR_ARGUMENTS workaround.
kScalarizedVecOrMatConstructor = 0x0001 << 7
}; };
IntermNodePatternMatcher(const unsigned int mask); IntermNodePatternMatcher(const unsigned int mask);
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "angle_gl.h" #include "angle_gl.h"
#include "common/angleutils.h" #include "common/angleutils.h"
#include "compiler/translator/IntermNodePatternMatcher.h"
#include "compiler/translator/IntermNode_util.h" #include "compiler/translator/IntermNode_util.h"
#include "compiler/translator/IntermTraverse.h" #include "compiler/translator/IntermTraverse.h"
...@@ -24,28 +25,6 @@ namespace sh ...@@ -24,28 +25,6 @@ namespace sh
namespace namespace
{ {
bool ContainsMatrixNode(const TIntermSequence &sequence)
{
for (size_t ii = 0; ii < sequence.size(); ++ii)
{
TIntermTyped *node = sequence[ii]->getAsTyped();
if (node && node->isMatrix())
return true;
}
return false;
}
bool ContainsVectorNode(const TIntermSequence &sequence)
{
for (size_t ii = 0; ii < sequence.size(); ++ii)
{
TIntermTyped *node = sequence[ii]->getAsTyped();
if (node && node->isVector())
return true;
}
return false;
}
TIntermBinary *ConstructVectorIndexBinaryNode(TIntermSymbol *symbolNode, int index) TIntermBinary *ConstructVectorIndexBinaryNode(TIntermSymbol *symbolNode, int index)
{ {
return new TIntermBinary(EOpIndexDirect, symbolNode, CreateIndexNode(index)); return new TIntermBinary(EOpIndexDirect, symbolNode, CreateIndexNode(index));
...@@ -66,7 +45,8 @@ class ScalarizeArgsTraverser : public TIntermTraverser ...@@ -66,7 +45,8 @@ class ScalarizeArgsTraverser : public TIntermTraverser
TSymbolTable *symbolTable) TSymbolTable *symbolTable)
: TIntermTraverser(true, false, false, symbolTable), : TIntermTraverser(true, false, false, symbolTable),
mShaderType(shaderType), mShaderType(shaderType),
mFragmentPrecisionHigh(fragmentPrecisionHigh) mFragmentPrecisionHigh(fragmentPrecisionHigh),
mNodesToScalarize(IntermNodePatternMatcher::kScalarizedVecOrMatConstructor)
{ {
} }
...@@ -92,16 +72,24 @@ class ScalarizeArgsTraverser : public TIntermTraverser ...@@ -92,16 +72,24 @@ class ScalarizeArgsTraverser : public TIntermTraverser
sh::GLenum mShaderType; sh::GLenum mShaderType;
bool mFragmentPrecisionHigh; bool mFragmentPrecisionHigh;
IntermNodePatternMatcher mNodesToScalarize;
}; };
bool ScalarizeArgsTraverser::visitAggregate(Visit visit, TIntermAggregate *node) bool ScalarizeArgsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
{ {
if (visit == PreVisit && node->getOp() == EOpConstruct) ASSERT(visit == PreVisit);
if (mNodesToScalarize.match(node, getParentNode()))
{ {
if (node->getType().isVector() && ContainsMatrixNode(*(node->getSequence()))) if (node->getType().isVector())
{
scalarizeArgs(node, false, true); scalarizeArgs(node, false, true);
else if (node->getType().isMatrix() && ContainsVectorNode(*(node->getSequence()))) }
else
{
ASSERT(node->getType().isMatrix());
scalarizeArgs(node, true, false); scalarizeArgs(node, true, false);
}
} }
return true; return true;
} }
...@@ -134,55 +122,55 @@ void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate, ...@@ -134,55 +122,55 @@ void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate,
ASSERT(!aggregate->isArray()); ASSERT(!aggregate->isArray());
int size = static_cast<int>(aggregate->getType().getObjectSize()); int size = static_cast<int>(aggregate->getType().getObjectSize());
TIntermSequence *sequence = aggregate->getSequence(); TIntermSequence *sequence = aggregate->getSequence();
TIntermSequence original(*sequence); TIntermSequence originalArgs(*sequence);
sequence->clear(); sequence->clear();
for (size_t ii = 0; ii < original.size(); ++ii) for (TIntermNode *originalArgNode : originalArgs)
{ {
ASSERT(size > 0); ASSERT(size > 0);
TIntermTyped *node = original[ii]->getAsTyped(); TIntermTyped *originalArg = originalArgNode->getAsTyped();
ASSERT(node); ASSERT(originalArg);
createTempVariable(node); createTempVariable(originalArg);
if (node->isScalar()) if (originalArg->isScalar())
{ {
sequence->push_back(createTempSymbol(node->getType())); sequence->push_back(createTempSymbol(originalArg->getType()));
size--; size--;
} }
else if (node->isVector()) else if (originalArg->isVector())
{ {
if (scalarizeVector) if (scalarizeVector)
{ {
int repeat = std::min(size, node->getNominalSize()); int repeat = std::min(size, originalArg->getNominalSize());
size -= repeat; size -= repeat;
for (int index = 0; index < repeat; ++index) for (int index = 0; index < repeat; ++index)
{ {
TIntermSymbol *symbolNode = createTempSymbol(node->getType()); TIntermSymbol *symbolNode = createTempSymbol(originalArg->getType());
TIntermBinary *newNode = ConstructVectorIndexBinaryNode(symbolNode, index); TIntermBinary *newNode = ConstructVectorIndexBinaryNode(symbolNode, index);
sequence->push_back(newNode); sequence->push_back(newNode);
} }
} }
else else
{ {
TIntermSymbol *symbolNode = createTempSymbol(node->getType()); TIntermSymbol *symbolNode = createTempSymbol(originalArg->getType());
sequence->push_back(symbolNode); sequence->push_back(symbolNode);
size -= node->getNominalSize(); size -= originalArg->getNominalSize();
} }
} }
else else
{ {
ASSERT(node->isMatrix()); ASSERT(originalArg->isMatrix());
if (scalarizeMatrix) if (scalarizeMatrix)
{ {
int colIndex = 0, rowIndex = 0; int colIndex = 0, rowIndex = 0;
int repeat = std::min(size, node->getCols() * node->getRows()); int repeat = std::min(size, originalArg->getCols() * originalArg->getRows());
size -= repeat; size -= repeat;
while (repeat > 0) while (repeat > 0)
{ {
TIntermSymbol *symbolNode = createTempSymbol(node->getType()); TIntermSymbol *symbolNode = createTempSymbol(originalArg->getType());
TIntermBinary *newNode = TIntermBinary *newNode =
ConstructMatrixIndexBinaryNode(symbolNode, colIndex, rowIndex); ConstructMatrixIndexBinaryNode(symbolNode, colIndex, rowIndex);
sequence->push_back(newNode); sequence->push_back(newNode);
rowIndex++; rowIndex++;
if (rowIndex >= node->getRows()) if (rowIndex >= originalArg->getRows())
{ {
rowIndex = 0; rowIndex = 0;
colIndex++; colIndex++;
...@@ -192,9 +180,9 @@ void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate, ...@@ -192,9 +180,9 @@ void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate,
} }
else else
{ {
TIntermSymbol *symbolNode = createTempSymbol(node->getType()); TIntermSymbol *symbolNode = createTempSymbol(originalArg->getType());
sequence->push_back(symbolNode); sequence->push_back(symbolNode);
size -= node->getCols() * node->getRows(); size -= originalArg->getCols() * originalArg->getRows();
} }
} }
} }
......
...@@ -82,6 +82,7 @@ ...@@ -82,6 +82,7 @@
'<(angle_path)/src/tests/compiler_tests/RemoveUnreferencedVariables_test.cpp', '<(angle_path)/src/tests/compiler_tests/RemoveUnreferencedVariables_test.cpp',
'<(angle_path)/src/tests/compiler_tests/RewriteDoWhile_test.cpp', '<(angle_path)/src/tests/compiler_tests/RewriteDoWhile_test.cpp',
'<(angle_path)/src/tests/compiler_tests/SamplerMultisample_test.cpp', '<(angle_path)/src/tests/compiler_tests/SamplerMultisample_test.cpp',
'<(angle_path)/src/tests/compiler_tests/ScalarizeVecAndMatConstructorArgs_test.cpp',
'<(angle_path)/src/tests/compiler_tests/ShaderExtension_test.cpp', '<(angle_path)/src/tests/compiler_tests/ShaderExtension_test.cpp',
'<(angle_path)/src/tests/compiler_tests/ShaderImage_test.cpp', '<(angle_path)/src/tests/compiler_tests/ShaderImage_test.cpp',
'<(angle_path)/src/tests/compiler_tests/ShaderValidation_test.cpp', '<(angle_path)/src/tests/compiler_tests/ShaderValidation_test.cpp',
......
//
// Copyright (c) 2017 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// ScalarizeVecAndMatConstructorArgs_test.cpp:
// Tests for scalarizing vector and matrix constructor args.
//
#include "GLSLANG/ShaderLang.h"
#include "angle_gl.h"
#include "gtest/gtest.h"
#include "tests/test_utils/compiler_test.h"
using namespace sh;
namespace
{
class ScalarizeVecAndMatConstructorArgsTest : public MatchOutputCodeTest
{
public:
ScalarizeVecAndMatConstructorArgsTest()
: MatchOutputCodeTest(GL_FRAGMENT_SHADER,
SH_SCALARIZE_VEC_AND_MAT_CONSTRUCTOR_ARGS,
SH_ESSL_OUTPUT)
{
}
};
// Verifies scalarizing matrix inside a vector constructor.
TEST_F(ScalarizeVecAndMatConstructorArgsTest, MatrixInVectorConstructor)
{
const std::string shaderString =
R"(
precision mediump float;
uniform mat2 umat2;
void main()
{
gl_FragColor = vec4(umat2);
})";
compile(shaderString);
std::vector<const char *> expectedStrings = {
"main()", " = _uumat2", "gl_FragColor = vec4(", "[0][0],", "[0][1],", "[1][0],", "[1][1])"};
EXPECT_TRUE(foundInCodeInOrder(expectedStrings));
}
// Verifies scalarizing a vector insized a matrix constructor.
TEST_F(ScalarizeVecAndMatConstructorArgsTest, VectorInMatrixConstructor)
{
const std::string shaderString =
R"(
precision mediump float;
uniform vec2 uvec2;
void main()
{
mat2 m = mat2(uvec2, uvec2);
gl_FragColor = vec4(m * uvec2, m * uvec2);
})";
compile(shaderString);
std::vector<const char *> expectedStrings = {
"main()", " = _uuvec2", "mat2(", "[0],", "[1],", "[0],", "[1])", "gl_FragColor = vec4("};
EXPECT_TRUE(foundInCodeInOrder(expectedStrings));
}
// Verifies that scalarizing vector and matrix constructor args inside a sequence operator preserves
// correct order of operations.
TEST_F(ScalarizeVecAndMatConstructorArgsTest, SequenceOperator)
{
const std::string shaderString =
R"(
precision mediump float;
uniform vec2 u;
void main()
{
vec2 v = u;
mat2 m = (v[0] += 1.0, mat2(v, v[1], -v[0]));
gl_FragColor = vec4(m[0], m[1]);
})";
compile(shaderString);
std::vector<const char *> expectedStrings = {"_uv[0] += 1.0", "-_uv[0]"};
EXPECT_TRUE(foundInCodeInOrder(expectedStrings));
}
// Verifies that scalarizing vector and matrix constructor args inside multiple declarations
// preserves the correct order of operations.
TEST_F(ScalarizeVecAndMatConstructorArgsTest, MultiDeclaration)
{
const std::string shaderString =
R"(
precision mediump float;
uniform vec2 u;
void main()
{
vec2 v = vec2(u[0]),
w = mat2(v, v) * u;
gl_FragColor = vec4(v, w);
})";
compile(shaderString);
std::vector<const char *> expectedStrings = {"vec2(_uu[0])", "mat2("};
EXPECT_TRUE(foundInCodeInOrder(expectedStrings));
}
} // anonymous namespace
...@@ -155,6 +155,29 @@ size_t MatchOutputCodeTest::findInCode(ShShaderOutput output, const char *string ...@@ -155,6 +155,29 @@ size_t MatchOutputCodeTest::findInCode(ShShaderOutput output, const char *string
return code->second.find(stringToFind); return code->second.find(stringToFind);
} }
bool MatchOutputCodeTest::foundInCodeInOrder(ShShaderOutput output,
std::vector<const char *> stringsToFind)
{
const auto code = mOutputCode.find(output);
EXPECT_NE(mOutputCode.end(), code);
if (code == mOutputCode.end())
{
return false;
}
size_t currentPos = 0;
for (const char *stringToFind : stringsToFind)
{
auto position = code->second.find(stringToFind, currentPos);
if (position == std::string::npos)
{
return false;
}
currentPos = position + strlen(stringToFind);
}
return true;
}
bool MatchOutputCodeTest::foundInCode(ShShaderOutput output, bool MatchOutputCodeTest::foundInCode(ShShaderOutput output,
const char *stringToFind, const char *stringToFind,
const int expectedOccurrences) const const int expectedOccurrences) const
...@@ -204,6 +227,18 @@ bool MatchOutputCodeTest::foundInCode(const char *stringToFind, const int expect ...@@ -204,6 +227,18 @@ bool MatchOutputCodeTest::foundInCode(const char *stringToFind, const int expect
return true; return true;
} }
bool MatchOutputCodeTest::foundInCodeInOrder(std::vector<const char *> stringsToFind)
{
for (auto &code : mOutputCode)
{
if (!foundInCodeInOrder(code.first, stringsToFind))
{
return false;
}
}
return true;
}
bool MatchOutputCodeTest::notFoundInCode(const char *stringToFind) const bool MatchOutputCodeTest::notFoundInCode(const char *stringToFind) const
{ {
for (auto &code : mOutputCode) for (auto &code : mOutputCode)
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#define TESTS_TEST_UTILS_COMPILER_TEST_H_ #define TESTS_TEST_UTILS_COMPILER_TEST_H_
#include <map> #include <map>
#include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
...@@ -68,6 +69,9 @@ class MatchOutputCodeTest : public testing::Test ...@@ -68,6 +69,9 @@ class MatchOutputCodeTest : public testing::Test
// source. If no matches are found, then string::npos is returned. // source. If no matches are found, then string::npos is returned.
size_t findInCode(ShShaderOutput output, const char *stringToFind) const; size_t findInCode(ShShaderOutput output, const char *stringToFind) const;
// Test that the strings are found in the specified output in the specified order.
bool foundInCodeInOrder(ShShaderOutput output, std::vector<const char *> stringsToFind);
// Test that the string occurs for exactly expectedOccurrences times // Test that the string occurs for exactly expectedOccurrences times
bool foundInCode(ShShaderOutput output, bool foundInCode(ShShaderOutput output,
const char *stringToFind, const char *stringToFind,
...@@ -79,6 +83,9 @@ class MatchOutputCodeTest : public testing::Test ...@@ -79,6 +83,9 @@ class MatchOutputCodeTest : public testing::Test
// Test that the string occurs for exactly expectedOccurrences times in all outputs // Test that the string occurs for exactly expectedOccurrences times in all outputs
bool foundInCode(const char *stringToFind, const int expectedOccurrences) const; bool foundInCode(const char *stringToFind, const int expectedOccurrences) const;
// Test that the strings are found in all outputs in the specified order.
bool foundInCodeInOrder(std::vector<const char *> stringsToFind);
// Test that the string is found in none of the outputs // Test that the string is found in none of the outputs
bool notFoundInCode(const char *stringToFind) const; bool notFoundInCode(const char *stringToFind) const;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment