Commit 1d122789 by Olli Etuaho

Fix constructor constant folding

The previous solution for constant folding constructors was significantly overengineered and partially incorrect. Switch to a much simpler constructor folding function that does not use an AST traverser, but simply iterates over the constant folded parameters of the constructor and doesn't do any unnecessary checks. It also reuses some code for constant folding other built-in functions. This fixes issues with initializing constant matrices with only a single parameter. Instead of copying the first component of the constructor parameter all over the matrix, passing a vec4 or matrix argument now assigns the values correctly. BUG=angleproject:1193 TEST=angle_unittests, WebGL conformance tests Change-Id: I50b10721ea30cb15843fba892c1b1a211f1d72e5 Reviewed-on: https://chromium-review.googlesource.com/311191 Tryjob-Request: Olli Etuaho <oetuaho@nvidia.com> Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org> Reviewed-by: 's avatarZhenyao Mo <zmo@chromium.org> Tested-by: 's avatarOlli Etuaho <oetuaho@nvidia.com>
parent 7c3848e5
......@@ -139,7 +139,6 @@
'compiler/translator/glslang_tab.h',
'compiler/translator/intermOut.cpp',
'compiler/translator/length_limits.h',
'compiler/translator/parseConst.cpp',
'compiler/translator/timing/RestrictFragmentShaderTiming.cpp',
'compiler/translator/timing/RestrictFragmentShaderTiming.h',
'compiler/translator/timing/RestrictVertexShaderTiming.cpp',
......
......@@ -950,7 +950,11 @@ TIntermTyped *TIntermAggregate::fold(TInfoSink &infoSink)
return nullptr;
}
}
TConstantUnion *constArray = TIntermConstantUnion::FoldAggregateBuiltIn(this, infoSink);
TConstantUnion *constArray = nullptr;
if (isConstructor())
constArray = TIntermConstantUnion::FoldAggregateConstructor(this, infoSink);
else
constArray = TIntermConstantUnion::FoldAggregateBuiltIn(this, infoSink);
// Nodes may be constant folded without being qualified as constant.
TQualifier resultQualifier = areChildrenConstQualified() ? EvqConst : EvqTemporary;
......@@ -1973,6 +1977,106 @@ bool TIntermConstantUnion::foldFloatTypeUnary(const TConstantUnion &parameter, F
}
// static
TConstantUnion *TIntermConstantUnion::FoldAggregateConstructor(TIntermAggregate *aggregate,
TInfoSink &infoSink)
{
ASSERT(aggregate->getSequence()->size() > 0u);
size_t resultSize = aggregate->getType().getObjectSize();
TConstantUnion *resultArray = new TConstantUnion[resultSize];
TBasicType basicType = aggregate->getBasicType();
size_t resultIndex = 0u;
if (aggregate->getSequence()->size() == 1u)
{
TIntermNode *argument = aggregate->getSequence()->front();
TIntermConstantUnion *argumentConstant = argument->getAsConstantUnion();
const TConstantUnion *argumentUnionArray = argumentConstant->getUnionArrayPointer();
// Check the special case of constructing a matrix diagonal from a single scalar,
// or a vector from a single scalar.
if (argumentConstant->getType().getObjectSize() == 1u)
{
if (aggregate->isMatrix())
{
int resultCols = aggregate->getType().getCols();
int resultRows = aggregate->getType().getRows();
for (int col = 0; col < resultCols; ++col)
{
for (int row = 0; row < resultRows; ++row)
{
if (col == row)
{
resultArray[resultIndex].cast(basicType, argumentUnionArray[0]);
}
else
{
resultArray[resultIndex].setFConst(0.0f);
}
++resultIndex;
}
}
}
else
{
while (resultIndex < resultSize)
{
resultArray[resultIndex].cast(basicType, argumentUnionArray[0]);
++resultIndex;
}
}
ASSERT(resultIndex == resultSize);
return resultArray;
}
else if (aggregate->isMatrix() && argumentConstant->isMatrix())
{
// The special case of constructing a matrix from a matrix.
int argumentCols = argumentConstant->getType().getCols();
int argumentRows = argumentConstant->getType().getRows();
int resultCols = aggregate->getType().getCols();
int resultRows = aggregate->getType().getRows();
for (int col = 0; col < resultCols; ++col)
{
for (int row = 0; row < resultRows; ++row)
{
if (col < argumentCols && row < argumentRows)
{
resultArray[resultIndex].cast(basicType,
argumentUnionArray[col * argumentRows + row]);
}
else if (col == row)
{
resultArray[resultIndex].setFConst(1.0f);
}
else
{
resultArray[resultIndex].setFConst(0.0f);
}
++resultIndex;
}
}
ASSERT(resultIndex == resultSize);
return resultArray;
}
}
for (TIntermNode *&argument : *aggregate->getSequence())
{
TIntermConstantUnion *argumentConstant = argument->getAsConstantUnion();
size_t argumentSize = argumentConstant->getType().getObjectSize();
const TConstantUnion *argumentUnionArray = argumentConstant->getUnionArrayPointer();
for (size_t i = 0u; i < argumentSize; ++i)
{
if (resultIndex >= resultSize)
break;
resultArray[resultIndex].cast(basicType, argumentUnionArray[i]);
++resultIndex;
}
}
ASSERT(resultIndex == resultSize);
return resultArray;
}
// static
TConstantUnion *TIntermConstantUnion::FoldAggregateBuiltIn(TIntermAggregate *aggregate, TInfoSink &infoSink)
{
TOperator op = aggregate->getOp();
......
......@@ -345,6 +345,8 @@ class TIntermConstantUnion : public TIntermTyped
TConstantUnion *foldUnaryWithDifferentReturnType(TOperator op, TInfoSink &infoSink);
TConstantUnion *foldUnaryWithSameReturnType(TOperator op, TInfoSink &infoSink);
static TConstantUnion *FoldAggregateConstructor(TIntermAggregate *aggregate,
TInfoSink &infoSink);
static TConstantUnion *FoldAggregateBuiltIn(TIntermAggregate *aggregate, TInfoSink &infoSink);
protected:
......
......@@ -469,33 +469,38 @@ TIntermTyped *TIntermediate::foldAggregateBuiltIn(TIntermAggregate *aggregate)
{
switch (aggregate->getOp())
{
case EOpAtan:
case EOpPow:
case EOpMod:
case EOpMin:
case EOpMax:
case EOpClamp:
case EOpMix:
case EOpStep:
case EOpSmoothStep:
case EOpMul:
case EOpOuterProduct:
case EOpLessThan:
case EOpLessThanEqual:
case EOpGreaterThan:
case EOpGreaterThanEqual:
case EOpVectorEqual:
case EOpVectorNotEqual:
case EOpDistance:
case EOpDot:
case EOpCross:
case EOpFaceForward:
case EOpReflect:
case EOpRefract:
return aggregate->fold(mInfoSink);
default:
// Constant folding not supported for the built-in.
return nullptr;
case EOpAtan:
case EOpPow:
case EOpMod:
case EOpMin:
case EOpMax:
case EOpClamp:
case EOpMix:
case EOpStep:
case EOpSmoothStep:
case EOpMul:
case EOpOuterProduct:
case EOpLessThan:
case EOpLessThanEqual:
case EOpGreaterThan:
case EOpGreaterThanEqual:
case EOpVectorEqual:
case EOpVectorNotEqual:
case EOpDistance:
case EOpDot:
case EOpCross:
case EOpFaceForward:
case EOpReflect:
case EOpRefract:
return aggregate->fold(mInfoSink);
default:
// TODO: Add support for folding array constructors
if (aggregate->isConstructor() && !aggregate->isArray())
{
return aggregate->fold(mInfoSink);
}
// Constant folding not supported for the built-in.
return nullptr;
}
return nullptr;
......
......@@ -54,9 +54,6 @@ class TIntermediate
int shaderVersion);
TIntermConstantUnion *addConstantUnion(
TConstantUnion *constantUnion, const TType &type, const TSourceLoc &line);
// TODO(zmo): Get rid of default value.
bool parseConstTree(const TSourceLoc &, TIntermNode *, TConstantUnion *,
TOperator, TType, bool singleConstantParam = false);
TIntermNode *addLoop(TLoopType, TIntermNode *, TIntermTyped *, TIntermTyped *,
TIntermNode *, const TSourceLoc &);
TIntermBranch *addBranch(TOperator, const TSourceLoc &);
......
......@@ -1369,17 +1369,6 @@ bool TParseContext::executeInitializer(const TSourceLoc &line,
return false;
}
bool TParseContext::areAllChildrenConstantFolded(TIntermAggregate *aggrNode)
{
ASSERT(aggrNode != nullptr);
for (TIntermNode *&node : *aggrNode->getSequence())
{
if (node->getAsConstantUnion() == nullptr)
return false;
}
return true;
}
TPublicType TParseContext::addFullySpecifiedType(TQualifier qualifier,
bool invariant,
TLayoutQualifier layoutQualifier,
......@@ -2318,7 +2307,7 @@ TIntermTyped *TParseContext::addConstructor(TIntermNode *arguments,
type->setPrecision(constructor->getPrecision());
}
TIntermTyped *constConstructor = foldConstConstructor(constructor, *type);
TIntermTyped *constConstructor = intermediate.foldAggregateBuiltIn(constructor);
if (constConstructor)
{
return constConstructor;
......@@ -2327,33 +2316,6 @@ TIntermTyped *TParseContext::addConstructor(TIntermNode *arguments,
return constructor;
}
TIntermTyped *TParseContext::foldConstConstructor(TIntermAggregate *aggrNode, const TType &type)
{
// TODO: Add support for folding array constructors
bool canBeFolded = areAllChildrenConstantFolded(aggrNode) && !type.isArray();
if (canBeFolded)
{
bool returnVal = false;
TConstantUnion *unionArray = new TConstantUnion[type.getObjectSize()];
if (aggrNode->getSequence()->size() == 1)
{
returnVal = intermediate.parseConstTree(aggrNode->getLine(), aggrNode, unionArray,
aggrNode->getOp(), type, true);
}
else
{
returnVal = intermediate.parseConstTree(aggrNode->getLine(), aggrNode, unionArray,
aggrNode->getOp(), type);
}
if (returnVal)
return 0;
return intermediate.addConstantUnion(unionArray, type, aggrNode->getLine());
}
return 0;
}
//
// This function returns the tree representation for the vector field(s) being accessed from contant
// vector.
......
......@@ -165,7 +165,6 @@ class TParseContext : angle::NonCopyable
void handlePragmaDirective(const TSourceLoc &loc, const char *name, const char *value, bool stdgl);
bool containsSampler(const TType &type);
bool areAllChildrenConstantFolded(TIntermAggregate *aggrNode);
const TFunction* findFunction(
const TSourceLoc &line, TFunction *pfnCall, int inputShaderVersion, bool *builtIn = 0);
bool executeInitializer(const TSourceLoc &line,
......@@ -247,7 +246,6 @@ class TParseContext : angle::NonCopyable
TOperator op,
TFunction *fnCall,
const TSourceLoc &line);
TIntermTyped *foldConstConstructor(TIntermAggregate *aggrNode, const TType &type);
TIntermTyped *addConstVectorNode(TVectorFields &fields,
TIntermConstantUnion *node,
const TSourceLoc &line,
......
//
// Copyright (c) 2002-2014 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
#include "compiler/translator/ParseContext.h"
//
// Use this class to carry along data from node to node in
// the traversal
//
class TConstTraverser : public TIntermTraverser
{
public:
TConstTraverser(TConstantUnion *cUnion, bool singleConstParam,
TOperator constructType, TInfoSink &sink, TType &t)
: TIntermTraverser(true, false, false),
error(false),
mIndex(0),
mUnionArray(cUnion),
mType(t),
mConstructorType(constructType),
mSingleConstantParam(singleConstParam),
mInfoSink(sink),
mSize(0),
mIsDiagonalMatrixInit(false),
mMatrixCols(0),
mMatrixRows(0)
{
}
bool error;
protected:
void visitSymbol(TIntermSymbol *) override;
void visitConstantUnion(TIntermConstantUnion *) override;
bool visitBinary(Visit visit, TIntermBinary *) override;
bool visitUnary(Visit visit, TIntermUnary *) override;
bool visitSelection(Visit visit, TIntermSelection *) override;
bool visitAggregate(Visit visit, TIntermAggregate *) override;
bool visitLoop(Visit visit, TIntermLoop *) override;
bool visitBranch(Visit visit, TIntermBranch *) override;
size_t mIndex;
TConstantUnion *mUnionArray;
TType mType;
TOperator mConstructorType;
bool mSingleConstantParam;
TInfoSink &mInfoSink;
size_t mSize; // size of the constructor ( 4 for vec4)
bool mIsDiagonalMatrixInit;
int mMatrixCols; // columns of the matrix
int mMatrixRows; // rows of the matrix
};
//
// The rest of the file are the traversal functions. The last one
// is the one that starts the traversal.
//
// Return true from interior nodes to have the external traversal
// continue on to children. If you process children yourself,
// return false.
//
void TConstTraverser::visitSymbol(TIntermSymbol *node)
{
mInfoSink.info.message(EPrefixInternalError, node->getLine(),
"Symbol Node found in constant constructor");
return;
}
bool TConstTraverser::visitBinary(Visit visit, TIntermBinary *node)
{
TQualifier qualifier = node->getType().getQualifier();
if (qualifier != EvqConst)
{
TString buf;
buf.append("'constructor' : assigning non-constant to ");
buf.append(mType.getCompleteString());
mInfoSink.info.message(EPrefixError, node->getLine(), buf.c_str());
error = true;
return false;
}
mInfoSink.info.message(EPrefixInternalError, node->getLine(),
"Binary Node found in constant constructor");
return false;
}
bool TConstTraverser::visitUnary(Visit visit, TIntermUnary *node)
{
TString buf;
buf.append("'constructor' : assigning non-constant to ");
buf.append(mType.getCompleteString());
mInfoSink.info.message(EPrefixError, node->getLine(), buf.c_str());
error = true;
return false;
}
bool TConstTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
{
if (!node->isConstructor() && node->getOp() != EOpComma)
{
TString buf;
buf.append("'constructor' : assigning non-constant to ");
buf.append(mType.getCompleteString());
mInfoSink.info.message(EPrefixError, node->getLine(), buf.c_str());
error = true;
return false;
}
if (node->getSequence()->size() == 0)
{
error = true;
return false;
}
bool flag = node->getSequence()->size() == 1 &&
(*node->getSequence())[0]->getAsTyped()->getAsConstantUnion();
if (flag)
{
mSingleConstantParam = true;
mConstructorType = node->getOp();
mSize = node->getType().getObjectSize();
if (node->getType().isMatrix())
{
mIsDiagonalMatrixInit = true;
mMatrixCols = node->getType().getCols();
mMatrixRows = node->getType().getRows();
}
}
for (TIntermSequence::iterator p = node->getSequence()->begin();
p != node->getSequence()->end(); p++)
{
if (node->getOp() == EOpComma)
mIndex = 0;
(*p)->traverse(this);
}
if (flag)
{
mSingleConstantParam = false;
mConstructorType = EOpNull;
mSize = 0;
mIsDiagonalMatrixInit = false;
mMatrixCols = 0;
mMatrixRows = 0;
}
return false;
}
bool TConstTraverser::visitSelection(Visit visit, TIntermSelection *node)
{
mInfoSink.info.message(EPrefixInternalError, node->getLine(),
"Selection Node found in constant constructor");
error = true;
return false;
}
void TConstTraverser::visitConstantUnion(TIntermConstantUnion *node)
{
if (!node->getUnionArrayPointer())
{
// The constant was not initialized, this should already have been logged
ASSERT(mInfoSink.info.size() != 0);
return;
}
TConstantUnion *leftUnionArray = mUnionArray;
size_t instanceSize = mType.getObjectSize();
TBasicType basicType = mType.getBasicType();
if (mIndex >= instanceSize)
return;
if (!mSingleConstantParam)
{
size_t objectSize = node->getType().getObjectSize();
const TConstantUnion *rightUnionArray = node->getUnionArrayPointer();
for (size_t i=0; i < objectSize; i++)
{
if (mIndex >= instanceSize)
return;
leftUnionArray[mIndex].cast(basicType, rightUnionArray[i]);
mIndex++;
}
}
else
{
size_t totalSize = mIndex + mSize;
const TConstantUnion *rightUnionArray = node->getUnionArrayPointer();
if (!mIsDiagonalMatrixInit)
{
int count = 0;
for (size_t i = mIndex; i < totalSize; i++)
{
if (i >= instanceSize)
return;
leftUnionArray[i].cast(basicType, rightUnionArray[count]);
mIndex++;
if (node->getType().getObjectSize() > 1)
count++;
}
}
else
{
// for matrix diagonal constructors from a single scalar
for (int i = 0, col = 0; col < mMatrixCols; col++)
{
for (int row = 0; row < mMatrixRows; row++, i++)
{
if (col == row)
{
leftUnionArray[i].cast(basicType, rightUnionArray[0]);
}
else
{
leftUnionArray[i].setFConst(0.0f);
}
mIndex++;
}
}
}
}
}
bool TConstTraverser::visitLoop(Visit visit, TIntermLoop *node)
{
mInfoSink.info.message(EPrefixInternalError, node->getLine(),
"Loop Node found in constant constructor");
error = true;
return false;
}
bool TConstTraverser::visitBranch(Visit visit, TIntermBranch *node)
{
mInfoSink.info.message(EPrefixInternalError, node->getLine(),
"Branch Node found in constant constructor");
error = true;
return false;
}
//
// This function is the one to call externally to start the traversal.
// Individual functions can be initialized to 0 to skip processing of that
// type of node. It's children will still be processed.
//
bool TIntermediate::parseConstTree(
const TSourceLoc &line, TIntermNode *root, TConstantUnion *unionArray,
TOperator constructorType, TType t, bool singleConstantParam)
{
if (root == 0)
return false;
TConstTraverser it(unionArray, singleConstantParam, constructorType,
mInfoSink, t);
root->traverse(&it);
if (it.error)
return true;
else
return false;
}
......@@ -564,3 +564,136 @@ TEST_F(ConstantFoldingTest, FoldUnaryMinusOnUintLiteral)
compile(shaderString);
ASSERT_TRUE(constantFoundInAST(0xFFFFFFFFu));
}
// Test that constant mat2 initialization with a mat2 parameter works correctly.
TEST_F(ConstantFoldingTest, FoldMat2ConstructorTakingMat2)
{
const std::string &shaderString =
"precision mediump float;\n"
"uniform float mult;\n"
"void main() {\n"
" const mat2 cm = mat2(mat2(0.0, 1.0, 2.0, 3.0));\n"
" mat2 m = cm * mult;\n"
" gl_FragColor = vec4(m[0], m[1]);\n"
"}\n";
compile(shaderString);
float outputElements[] =
{
0.0f, 1.0f,
2.0f, 3.0f
};
std::vector<float> result(outputElements, outputElements + 4);
ASSERT_TRUE(constantVectorFoundInAST(result));
}
// Test that constant mat2 initialization with an int parameter works correctly.
TEST_F(ConstantFoldingTest, FoldMat2ConstructorTakingScalar)
{
const std::string &shaderString =
"precision mediump float;\n"
"uniform float mult;\n"
"void main() {\n"
" const mat2 cm = mat2(3);\n"
" mat2 m = cm * mult;\n"
" gl_FragColor = vec4(m[0], m[1]);\n"
"}\n";
compile(shaderString);
float outputElements[] =
{
3.0f, 0.0f,
0.0f, 3.0f
};
std::vector<float> result(outputElements, outputElements + 4);
ASSERT_TRUE(constantVectorFoundInAST(result));
}
// Test that constant mat2 initialization with a mix of parameters works correctly.
TEST_F(ConstantFoldingTest, FoldMat2ConstructorTakingMix)
{
const std::string &shaderString =
"precision mediump float;\n"
"uniform float mult;\n"
"void main() {\n"
" const mat2 cm = mat2(-1, vec2(0.0, 1.0), vec4(2.0));\n"
" mat2 m = cm * mult;\n"
" gl_FragColor = vec4(m[0], m[1]);\n"
"}\n";
compile(shaderString);
float outputElements[] =
{
-1.0, 0.0f,
1.0f, 2.0f
};
std::vector<float> result(outputElements, outputElements + 4);
ASSERT_TRUE(constantVectorFoundInAST(result));
}
// Test that constant mat2 initialization with a mat3 parameter works correctly.
TEST_F(ConstantFoldingTest, FoldMat2ConstructorTakingMat3)
{
const std::string &shaderString =
"precision mediump float;\n"
"uniform float mult;\n"
"void main() {\n"
" const mat2 cm = mat2(mat3(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0));\n"
" mat2 m = cm * mult;\n"
" gl_FragColor = vec4(m[0], m[1]);\n"
"}\n";
compile(shaderString);
float outputElements[] =
{
0.0f, 1.0f,
3.0f, 4.0f
};
std::vector<float> result(outputElements, outputElements + 4);
ASSERT_TRUE(constantVectorFoundInAST(result));
}
// Test that constant mat4x3 initialization with a mat3x2 parameter works correctly.
TEST_F(ConstantFoldingTest, FoldMat4x3ConstructorTakingMat3x2)
{
const std::string &shaderString =
"#version 300 es\n"
"precision mediump float;\n"
"uniform float mult;\n"
"out vec4 my_FragColor;\n"
"void main() {\n"
" const mat4x3 cm = mat4x3(mat3x2(1.0, 2.0,\n"
" 3.0, 4.0,\n"
" 5.0, 6.0));\n"
" mat4x3 m = cm * mult;\n"
" my_FragColor = vec4(m[0], m[1][0]);\n"
"}\n";
compile(shaderString);
float outputElements[] =
{
1.0f, 2.0f, 0.0f,
3.0f, 4.0f, 0.0f,
5.0f, 6.0f, 1.0f,
0.0f, 0.0f, 0.0f
};
std::vector<float> result(outputElements, outputElements + 12);
ASSERT_TRUE(constantVectorFoundInAST(result));
}
// Test that constant mat2 initialization with a vec4 parameter works correctly.
TEST_F(ConstantFoldingTest, FoldMat2ConstructorTakingVec4)
{
const std::string &shaderString =
"precision mediump float;\n"
"uniform float mult;\n"
"void main() {\n"
" const mat2 cm = mat2(vec4(0.0, 1.0, 2.0, 3.0));\n"
" mat2 m = cm * mult;\n"
" gl_FragColor = vec4(m[0], m[1]);\n"
"}\n";
compile(shaderString);
float outputElements[] =
{
0.0f, 1.0f,
2.0f, 3.0f
};
std::vector<float> result(outputElements, outputElements + 4);
ASSERT_TRUE(constantVectorFoundInAST(result));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment