Commit 1d122789 by Olli Etuaho

Fix constructor constant folding

The previous solution for constant folding constructors was significantly overengineered and partially incorrect. Switch to a much simpler constructor folding function that does not use an AST traverser, but simply iterates over the constant folded parameters of the constructor and doesn't do any unnecessary checks. It also reuses some code for constant folding other built-in functions. This fixes issues with initializing constant matrices with only a single parameter. Instead of copying the first component of the constructor parameter all over the matrix, passing a vec4 or matrix argument now assigns the values correctly. BUG=angleproject:1193 TEST=angle_unittests, WebGL conformance tests Change-Id: I50b10721ea30cb15843fba892c1b1a211f1d72e5 Reviewed-on: https://chromium-review.googlesource.com/311191 Tryjob-Request: Olli Etuaho <oetuaho@nvidia.com> Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org> Reviewed-by: 's avatarZhenyao Mo <zmo@chromium.org> Tested-by: 's avatarOlli Etuaho <oetuaho@nvidia.com>
parent 7c3848e5
...@@ -139,7 +139,6 @@ ...@@ -139,7 +139,6 @@
'compiler/translator/glslang_tab.h', 'compiler/translator/glslang_tab.h',
'compiler/translator/intermOut.cpp', 'compiler/translator/intermOut.cpp',
'compiler/translator/length_limits.h', 'compiler/translator/length_limits.h',
'compiler/translator/parseConst.cpp',
'compiler/translator/timing/RestrictFragmentShaderTiming.cpp', 'compiler/translator/timing/RestrictFragmentShaderTiming.cpp',
'compiler/translator/timing/RestrictFragmentShaderTiming.h', 'compiler/translator/timing/RestrictFragmentShaderTiming.h',
'compiler/translator/timing/RestrictVertexShaderTiming.cpp', 'compiler/translator/timing/RestrictVertexShaderTiming.cpp',
......
...@@ -950,7 +950,11 @@ TIntermTyped *TIntermAggregate::fold(TInfoSink &infoSink) ...@@ -950,7 +950,11 @@ TIntermTyped *TIntermAggregate::fold(TInfoSink &infoSink)
return nullptr; return nullptr;
} }
} }
TConstantUnion *constArray = TIntermConstantUnion::FoldAggregateBuiltIn(this, infoSink); TConstantUnion *constArray = nullptr;
if (isConstructor())
constArray = TIntermConstantUnion::FoldAggregateConstructor(this, infoSink);
else
constArray = TIntermConstantUnion::FoldAggregateBuiltIn(this, infoSink);
// Nodes may be constant folded without being qualified as constant. // Nodes may be constant folded without being qualified as constant.
TQualifier resultQualifier = areChildrenConstQualified() ? EvqConst : EvqTemporary; TQualifier resultQualifier = areChildrenConstQualified() ? EvqConst : EvqTemporary;
...@@ -1973,6 +1977,106 @@ bool TIntermConstantUnion::foldFloatTypeUnary(const TConstantUnion &parameter, F ...@@ -1973,6 +1977,106 @@ bool TIntermConstantUnion::foldFloatTypeUnary(const TConstantUnion &parameter, F
} }
// static // static
TConstantUnion *TIntermConstantUnion::FoldAggregateConstructor(TIntermAggregate *aggregate,
TInfoSink &infoSink)
{
ASSERT(aggregate->getSequence()->size() > 0u);
size_t resultSize = aggregate->getType().getObjectSize();
TConstantUnion *resultArray = new TConstantUnion[resultSize];
TBasicType basicType = aggregate->getBasicType();
size_t resultIndex = 0u;
if (aggregate->getSequence()->size() == 1u)
{
TIntermNode *argument = aggregate->getSequence()->front();
TIntermConstantUnion *argumentConstant = argument->getAsConstantUnion();
const TConstantUnion *argumentUnionArray = argumentConstant->getUnionArrayPointer();
// Check the special case of constructing a matrix diagonal from a single scalar,
// or a vector from a single scalar.
if (argumentConstant->getType().getObjectSize() == 1u)
{
if (aggregate->isMatrix())
{
int resultCols = aggregate->getType().getCols();
int resultRows = aggregate->getType().getRows();
for (int col = 0; col < resultCols; ++col)
{
for (int row = 0; row < resultRows; ++row)
{
if (col == row)
{
resultArray[resultIndex].cast(basicType, argumentUnionArray[0]);
}
else
{
resultArray[resultIndex].setFConst(0.0f);
}
++resultIndex;
}
}
}
else
{
while (resultIndex < resultSize)
{
resultArray[resultIndex].cast(basicType, argumentUnionArray[0]);
++resultIndex;
}
}
ASSERT(resultIndex == resultSize);
return resultArray;
}
else if (aggregate->isMatrix() && argumentConstant->isMatrix())
{
// The special case of constructing a matrix from a matrix.
int argumentCols = argumentConstant->getType().getCols();
int argumentRows = argumentConstant->getType().getRows();
int resultCols = aggregate->getType().getCols();
int resultRows = aggregate->getType().getRows();
for (int col = 0; col < resultCols; ++col)
{
for (int row = 0; row < resultRows; ++row)
{
if (col < argumentCols && row < argumentRows)
{
resultArray[resultIndex].cast(basicType,
argumentUnionArray[col * argumentRows + row]);
}
else if (col == row)
{
resultArray[resultIndex].setFConst(1.0f);
}
else
{
resultArray[resultIndex].setFConst(0.0f);
}
++resultIndex;
}
}
ASSERT(resultIndex == resultSize);
return resultArray;
}
}
for (TIntermNode *&argument : *aggregate->getSequence())
{
TIntermConstantUnion *argumentConstant = argument->getAsConstantUnion();
size_t argumentSize = argumentConstant->getType().getObjectSize();
const TConstantUnion *argumentUnionArray = argumentConstant->getUnionArrayPointer();
for (size_t i = 0u; i < argumentSize; ++i)
{
if (resultIndex >= resultSize)
break;
resultArray[resultIndex].cast(basicType, argumentUnionArray[i]);
++resultIndex;
}
}
ASSERT(resultIndex == resultSize);
return resultArray;
}
// static
TConstantUnion *TIntermConstantUnion::FoldAggregateBuiltIn(TIntermAggregate *aggregate, TInfoSink &infoSink) TConstantUnion *TIntermConstantUnion::FoldAggregateBuiltIn(TIntermAggregate *aggregate, TInfoSink &infoSink)
{ {
TOperator op = aggregate->getOp(); TOperator op = aggregate->getOp();
......
...@@ -345,6 +345,8 @@ class TIntermConstantUnion : public TIntermTyped ...@@ -345,6 +345,8 @@ class TIntermConstantUnion : public TIntermTyped
TConstantUnion *foldUnaryWithDifferentReturnType(TOperator op, TInfoSink &infoSink); TConstantUnion *foldUnaryWithDifferentReturnType(TOperator op, TInfoSink &infoSink);
TConstantUnion *foldUnaryWithSameReturnType(TOperator op, TInfoSink &infoSink); TConstantUnion *foldUnaryWithSameReturnType(TOperator op, TInfoSink &infoSink);
static TConstantUnion *FoldAggregateConstructor(TIntermAggregate *aggregate,
TInfoSink &infoSink);
static TConstantUnion *FoldAggregateBuiltIn(TIntermAggregate *aggregate, TInfoSink &infoSink); static TConstantUnion *FoldAggregateBuiltIn(TIntermAggregate *aggregate, TInfoSink &infoSink);
protected: protected:
......
...@@ -469,33 +469,38 @@ TIntermTyped *TIntermediate::foldAggregateBuiltIn(TIntermAggregate *aggregate) ...@@ -469,33 +469,38 @@ TIntermTyped *TIntermediate::foldAggregateBuiltIn(TIntermAggregate *aggregate)
{ {
switch (aggregate->getOp()) switch (aggregate->getOp())
{ {
case EOpAtan: case EOpAtan:
case EOpPow: case EOpPow:
case EOpMod: case EOpMod:
case EOpMin: case EOpMin:
case EOpMax: case EOpMax:
case EOpClamp: case EOpClamp:
case EOpMix: case EOpMix:
case EOpStep: case EOpStep:
case EOpSmoothStep: case EOpSmoothStep:
case EOpMul: case EOpMul:
case EOpOuterProduct: case EOpOuterProduct:
case EOpLessThan: case EOpLessThan:
case EOpLessThanEqual: case EOpLessThanEqual:
case EOpGreaterThan: case EOpGreaterThan:
case EOpGreaterThanEqual: case EOpGreaterThanEqual:
case EOpVectorEqual: case EOpVectorEqual:
case EOpVectorNotEqual: case EOpVectorNotEqual:
case EOpDistance: case EOpDistance:
case EOpDot: case EOpDot:
case EOpCross: case EOpCross:
case EOpFaceForward: case EOpFaceForward:
case EOpReflect: case EOpReflect:
case EOpRefract: case EOpRefract:
return aggregate->fold(mInfoSink); return aggregate->fold(mInfoSink);
default: default:
// Constant folding not supported for the built-in. // TODO: Add support for folding array constructors
return nullptr; if (aggregate->isConstructor() && !aggregate->isArray())
{
return aggregate->fold(mInfoSink);
}
// Constant folding not supported for the built-in.
return nullptr;
} }
return nullptr; return nullptr;
......
...@@ -54,9 +54,6 @@ class TIntermediate ...@@ -54,9 +54,6 @@ class TIntermediate
int shaderVersion); int shaderVersion);
TIntermConstantUnion *addConstantUnion( TIntermConstantUnion *addConstantUnion(
TConstantUnion *constantUnion, const TType &type, const TSourceLoc &line); TConstantUnion *constantUnion, const TType &type, const TSourceLoc &line);
// TODO(zmo): Get rid of default value.
bool parseConstTree(const TSourceLoc &, TIntermNode *, TConstantUnion *,
TOperator, TType, bool singleConstantParam = false);
TIntermNode *addLoop(TLoopType, TIntermNode *, TIntermTyped *, TIntermTyped *, TIntermNode *addLoop(TLoopType, TIntermNode *, TIntermTyped *, TIntermTyped *,
TIntermNode *, const TSourceLoc &); TIntermNode *, const TSourceLoc &);
TIntermBranch *addBranch(TOperator, const TSourceLoc &); TIntermBranch *addBranch(TOperator, const TSourceLoc &);
......
...@@ -1369,17 +1369,6 @@ bool TParseContext::executeInitializer(const TSourceLoc &line, ...@@ -1369,17 +1369,6 @@ bool TParseContext::executeInitializer(const TSourceLoc &line,
return false; return false;
} }
bool TParseContext::areAllChildrenConstantFolded(TIntermAggregate *aggrNode)
{
ASSERT(aggrNode != nullptr);
for (TIntermNode *&node : *aggrNode->getSequence())
{
if (node->getAsConstantUnion() == nullptr)
return false;
}
return true;
}
TPublicType TParseContext::addFullySpecifiedType(TQualifier qualifier, TPublicType TParseContext::addFullySpecifiedType(TQualifier qualifier,
bool invariant, bool invariant,
TLayoutQualifier layoutQualifier, TLayoutQualifier layoutQualifier,
...@@ -2318,7 +2307,7 @@ TIntermTyped *TParseContext::addConstructor(TIntermNode *arguments, ...@@ -2318,7 +2307,7 @@ TIntermTyped *TParseContext::addConstructor(TIntermNode *arguments,
type->setPrecision(constructor->getPrecision()); type->setPrecision(constructor->getPrecision());
} }
TIntermTyped *constConstructor = foldConstConstructor(constructor, *type); TIntermTyped *constConstructor = intermediate.foldAggregateBuiltIn(constructor);
if (constConstructor) if (constConstructor)
{ {
return constConstructor; return constConstructor;
...@@ -2327,33 +2316,6 @@ TIntermTyped *TParseContext::addConstructor(TIntermNode *arguments, ...@@ -2327,33 +2316,6 @@ TIntermTyped *TParseContext::addConstructor(TIntermNode *arguments,
return constructor; return constructor;
} }
TIntermTyped *TParseContext::foldConstConstructor(TIntermAggregate *aggrNode, const TType &type)
{
// TODO: Add support for folding array constructors
bool canBeFolded = areAllChildrenConstantFolded(aggrNode) && !type.isArray();
if (canBeFolded)
{
bool returnVal = false;
TConstantUnion *unionArray = new TConstantUnion[type.getObjectSize()];
if (aggrNode->getSequence()->size() == 1)
{
returnVal = intermediate.parseConstTree(aggrNode->getLine(), aggrNode, unionArray,
aggrNode->getOp(), type, true);
}
else
{
returnVal = intermediate.parseConstTree(aggrNode->getLine(), aggrNode, unionArray,
aggrNode->getOp(), type);
}
if (returnVal)
return 0;
return intermediate.addConstantUnion(unionArray, type, aggrNode->getLine());
}
return 0;
}
// //
// This function returns the tree representation for the vector field(s) being accessed from contant // This function returns the tree representation for the vector field(s) being accessed from contant
// vector. // vector.
......
...@@ -165,7 +165,6 @@ class TParseContext : angle::NonCopyable ...@@ -165,7 +165,6 @@ class TParseContext : angle::NonCopyable
void handlePragmaDirective(const TSourceLoc &loc, const char *name, const char *value, bool stdgl); void handlePragmaDirective(const TSourceLoc &loc, const char *name, const char *value, bool stdgl);
bool containsSampler(const TType &type); bool containsSampler(const TType &type);
bool areAllChildrenConstantFolded(TIntermAggregate *aggrNode);
const TFunction* findFunction( const TFunction* findFunction(
const TSourceLoc &line, TFunction *pfnCall, int inputShaderVersion, bool *builtIn = 0); const TSourceLoc &line, TFunction *pfnCall, int inputShaderVersion, bool *builtIn = 0);
bool executeInitializer(const TSourceLoc &line, bool executeInitializer(const TSourceLoc &line,
...@@ -247,7 +246,6 @@ class TParseContext : angle::NonCopyable ...@@ -247,7 +246,6 @@ class TParseContext : angle::NonCopyable
TOperator op, TOperator op,
TFunction *fnCall, TFunction *fnCall,
const TSourceLoc &line); const TSourceLoc &line);
TIntermTyped *foldConstConstructor(TIntermAggregate *aggrNode, const TType &type);
TIntermTyped *addConstVectorNode(TVectorFields &fields, TIntermTyped *addConstVectorNode(TVectorFields &fields,
TIntermConstantUnion *node, TIntermConstantUnion *node,
const TSourceLoc &line, const TSourceLoc &line,
......
//
// Copyright (c) 2002-2014 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
#include "compiler/translator/ParseContext.h"
//
// Use this class to carry along data from node to node in
// the traversal
//
class TConstTraverser : public TIntermTraverser
{
public:
TConstTraverser(TConstantUnion *cUnion, bool singleConstParam,
TOperator constructType, TInfoSink &sink, TType &t)
: TIntermTraverser(true, false, false),
error(false),
mIndex(0),
mUnionArray(cUnion),
mType(t),
mConstructorType(constructType),
mSingleConstantParam(singleConstParam),
mInfoSink(sink),
mSize(0),
mIsDiagonalMatrixInit(false),
mMatrixCols(0),
mMatrixRows(0)
{
}
bool error;
protected:
void visitSymbol(TIntermSymbol *) override;
void visitConstantUnion(TIntermConstantUnion *) override;
bool visitBinary(Visit visit, TIntermBinary *) override;
bool visitUnary(Visit visit, TIntermUnary *) override;
bool visitSelection(Visit visit, TIntermSelection *) override;
bool visitAggregate(Visit visit, TIntermAggregate *) override;
bool visitLoop(Visit visit, TIntermLoop *) override;
bool visitBranch(Visit visit, TIntermBranch *) override;
size_t mIndex;
TConstantUnion *mUnionArray;
TType mType;
TOperator mConstructorType;
bool mSingleConstantParam;
TInfoSink &mInfoSink;
size_t mSize; // size of the constructor ( 4 for vec4)
bool mIsDiagonalMatrixInit;
int mMatrixCols; // columns of the matrix
int mMatrixRows; // rows of the matrix
};
//
// The rest of the file are the traversal functions. The last one
// is the one that starts the traversal.
//
// Return true from interior nodes to have the external traversal
// continue on to children. If you process children yourself,
// return false.
//
void TConstTraverser::visitSymbol(TIntermSymbol *node)
{
mInfoSink.info.message(EPrefixInternalError, node->getLine(),
"Symbol Node found in constant constructor");
return;
}
bool TConstTraverser::visitBinary(Visit visit, TIntermBinary *node)
{
TQualifier qualifier = node->getType().getQualifier();
if (qualifier != EvqConst)
{
TString buf;
buf.append("'constructor' : assigning non-constant to ");
buf.append(mType.getCompleteString());
mInfoSink.info.message(EPrefixError, node->getLine(), buf.c_str());
error = true;
return false;
}
mInfoSink.info.message(EPrefixInternalError, node->getLine(),
"Binary Node found in constant constructor");
return false;
}
bool TConstTraverser::visitUnary(Visit visit, TIntermUnary *node)
{
TString buf;
buf.append("'constructor' : assigning non-constant to ");
buf.append(mType.getCompleteString());
mInfoSink.info.message(EPrefixError, node->getLine(), buf.c_str());
error = true;
return false;
}
bool TConstTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
{
if (!node->isConstructor() && node->getOp() != EOpComma)
{
TString buf;
buf.append("'constructor' : assigning non-constant to ");
buf.append(mType.getCompleteString());
mInfoSink.info.message(EPrefixError, node->getLine(), buf.c_str());
error = true;
return false;
}
if (node->getSequence()->size() == 0)
{
error = true;
return false;
}
bool flag = node->getSequence()->size() == 1 &&
(*node->getSequence())[0]->getAsTyped()->getAsConstantUnion();
if (flag)
{
mSingleConstantParam = true;
mConstructorType = node->getOp();
mSize = node->getType().getObjectSize();
if (node->getType().isMatrix())
{
mIsDiagonalMatrixInit = true;
mMatrixCols = node->getType().getCols();
mMatrixRows = node->getType().getRows();
}
}
for (TIntermSequence::iterator p = node->getSequence()->begin();
p != node->getSequence()->end(); p++)
{
if (node->getOp() == EOpComma)
mIndex = 0;
(*p)->traverse(this);
}
if (flag)
{
mSingleConstantParam = false;
mConstructorType = EOpNull;
mSize = 0;
mIsDiagonalMatrixInit = false;
mMatrixCols = 0;
mMatrixRows = 0;
}
return false;
}
bool TConstTraverser::visitSelection(Visit visit, TIntermSelection *node)
{
mInfoSink.info.message(EPrefixInternalError, node->getLine(),
"Selection Node found in constant constructor");
error = true;
return false;
}
void TConstTraverser::visitConstantUnion(TIntermConstantUnion *node)
{
if (!node->getUnionArrayPointer())
{
// The constant was not initialized, this should already have been logged
ASSERT(mInfoSink.info.size() != 0);
return;
}
TConstantUnion *leftUnionArray = mUnionArray;
size_t instanceSize = mType.getObjectSize();
TBasicType basicType = mType.getBasicType();
if (mIndex >= instanceSize)
return;
if (!mSingleConstantParam)
{
size_t objectSize = node->getType().getObjectSize();
const TConstantUnion *rightUnionArray = node->getUnionArrayPointer();
for (size_t i=0; i < objectSize; i++)
{
if (mIndex >= instanceSize)
return;
leftUnionArray[mIndex].cast(basicType, rightUnionArray[i]);
mIndex++;
}
}
else
{
size_t totalSize = mIndex + mSize;
const TConstantUnion *rightUnionArray = node->getUnionArrayPointer();
if (!mIsDiagonalMatrixInit)
{
int count = 0;
for (size_t i = mIndex; i < totalSize; i++)
{
if (i >= instanceSize)
return;
leftUnionArray[i].cast(basicType, rightUnionArray[count]);
mIndex++;
if (node->getType().getObjectSize() > 1)
count++;
}
}
else
{
// for matrix diagonal constructors from a single scalar
for (int i = 0, col = 0; col < mMatrixCols; col++)
{
for (int row = 0; row < mMatrixRows; row++, i++)
{
if (col == row)
{
leftUnionArray[i].cast(basicType, rightUnionArray[0]);
}
else
{
leftUnionArray[i].setFConst(0.0f);
}
mIndex++;
}
}
}
}
}
bool TConstTraverser::visitLoop(Visit visit, TIntermLoop *node)
{
mInfoSink.info.message(EPrefixInternalError, node->getLine(),
"Loop Node found in constant constructor");
error = true;
return false;
}
bool TConstTraverser::visitBranch(Visit visit, TIntermBranch *node)
{
mInfoSink.info.message(EPrefixInternalError, node->getLine(),
"Branch Node found in constant constructor");
error = true;
return false;
}
//
// This function is the one to call externally to start the traversal.
// Individual functions can be initialized to 0 to skip processing of that
// type of node. It's children will still be processed.
//
bool TIntermediate::parseConstTree(
const TSourceLoc &line, TIntermNode *root, TConstantUnion *unionArray,
TOperator constructorType, TType t, bool singleConstantParam)
{
if (root == 0)
return false;
TConstTraverser it(unionArray, singleConstantParam, constructorType,
mInfoSink, t);
root->traverse(&it);
if (it.error)
return true;
else
return false;
}
...@@ -564,3 +564,136 @@ TEST_F(ConstantFoldingTest, FoldUnaryMinusOnUintLiteral) ...@@ -564,3 +564,136 @@ TEST_F(ConstantFoldingTest, FoldUnaryMinusOnUintLiteral)
compile(shaderString); compile(shaderString);
ASSERT_TRUE(constantFoundInAST(0xFFFFFFFFu)); ASSERT_TRUE(constantFoundInAST(0xFFFFFFFFu));
} }
// Test that constant mat2 initialization with a mat2 parameter works correctly.
TEST_F(ConstantFoldingTest, FoldMat2ConstructorTakingMat2)
{
const std::string &shaderString =
"precision mediump float;\n"
"uniform float mult;\n"
"void main() {\n"
" const mat2 cm = mat2(mat2(0.0, 1.0, 2.0, 3.0));\n"
" mat2 m = cm * mult;\n"
" gl_FragColor = vec4(m[0], m[1]);\n"
"}\n";
compile(shaderString);
float outputElements[] =
{
0.0f, 1.0f,
2.0f, 3.0f
};
std::vector<float> result(outputElements, outputElements + 4);
ASSERT_TRUE(constantVectorFoundInAST(result));
}
// Test that constant mat2 initialization with an int parameter works correctly.
TEST_F(ConstantFoldingTest, FoldMat2ConstructorTakingScalar)
{
const std::string &shaderString =
"precision mediump float;\n"
"uniform float mult;\n"
"void main() {\n"
" const mat2 cm = mat2(3);\n"
" mat2 m = cm * mult;\n"
" gl_FragColor = vec4(m[0], m[1]);\n"
"}\n";
compile(shaderString);
float outputElements[] =
{
3.0f, 0.0f,
0.0f, 3.0f
};
std::vector<float> result(outputElements, outputElements + 4);
ASSERT_TRUE(constantVectorFoundInAST(result));
}
// Test that constant mat2 initialization with a mix of parameters works correctly.
TEST_F(ConstantFoldingTest, FoldMat2ConstructorTakingMix)
{
const std::string &shaderString =
"precision mediump float;\n"
"uniform float mult;\n"
"void main() {\n"
" const mat2 cm = mat2(-1, vec2(0.0, 1.0), vec4(2.0));\n"
" mat2 m = cm * mult;\n"
" gl_FragColor = vec4(m[0], m[1]);\n"
"}\n";
compile(shaderString);
float outputElements[] =
{
-1.0, 0.0f,
1.0f, 2.0f
};
std::vector<float> result(outputElements, outputElements + 4);
ASSERT_TRUE(constantVectorFoundInAST(result));
}
// Test that constant mat2 initialization with a mat3 parameter works correctly.
TEST_F(ConstantFoldingTest, FoldMat2ConstructorTakingMat3)
{
const std::string &shaderString =
"precision mediump float;\n"
"uniform float mult;\n"
"void main() {\n"
" const mat2 cm = mat2(mat3(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0));\n"
" mat2 m = cm * mult;\n"
" gl_FragColor = vec4(m[0], m[1]);\n"
"}\n";
compile(shaderString);
float outputElements[] =
{
0.0f, 1.0f,
3.0f, 4.0f
};
std::vector<float> result(outputElements, outputElements + 4);
ASSERT_TRUE(constantVectorFoundInAST(result));
}
// Test that constant mat4x3 initialization with a mat3x2 parameter works correctly.
TEST_F(ConstantFoldingTest, FoldMat4x3ConstructorTakingMat3x2)
{
const std::string &shaderString =
"#version 300 es\n"
"precision mediump float;\n"
"uniform float mult;\n"
"out vec4 my_FragColor;\n"
"void main() {\n"
" const mat4x3 cm = mat4x3(mat3x2(1.0, 2.0,\n"
" 3.0, 4.0,\n"
" 5.0, 6.0));\n"
" mat4x3 m = cm * mult;\n"
" my_FragColor = vec4(m[0], m[1][0]);\n"
"}\n";
compile(shaderString);
float outputElements[] =
{
1.0f, 2.0f, 0.0f,
3.0f, 4.0f, 0.0f,
5.0f, 6.0f, 1.0f,
0.0f, 0.0f, 0.0f
};
std::vector<float> result(outputElements, outputElements + 12);
ASSERT_TRUE(constantVectorFoundInAST(result));
}
// Test that constant mat2 initialization with a vec4 parameter works correctly.
TEST_F(ConstantFoldingTest, FoldMat2ConstructorTakingVec4)
{
const std::string &shaderString =
"precision mediump float;\n"
"uniform float mult;\n"
"void main() {\n"
" const mat2 cm = mat2(vec4(0.0, 1.0, 2.0, 3.0));\n"
" mat2 m = cm * mult;\n"
" gl_FragColor = vec4(m[0], m[1]);\n"
"}\n";
compile(shaderString);
float outputElements[] =
{
0.0f, 1.0f,
2.0f, 3.0f
};
std::vector<float> result(outputElements, outputElements + 4);
ASSERT_TRUE(constantVectorFoundInAST(result));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment