Commit 2cb7b835 by Olli Etuaho Committed by Zhenyao Mo

Clean up binary operation constant folding code

Fix mixed up comments, remove unnecessary type conversions, clarify variable names and improve formatting in a few places. TEST=angle_unittests, WebGL conformance tests BUG=angleproject:913 Change-Id: Ice8fe3682d8e97f42747752302a1fba116132df4 Reviewed-on: https://chromium-review.googlesource.com/266843Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org> Tested-by: 's avatarOlli Etuaho <oetuaho@nvidia.com> Reviewed-by: 's avatarZhenyao Mo <zmo@chromium.org>
parent a8c414ba
......@@ -666,7 +666,7 @@ bool TIntermBinary::promote(TInfoSink &infoSink)
// Returns the node to keep using, which may or may not be the node passed in.
//
TIntermTyped *TIntermConstantUnion::fold(
TOperator op, TIntermTyped *constantNode, TInfoSink &infoSink)
TOperator op, TIntermConstantUnion *rightNode, TInfoSink &infoSink)
{
ConstantUnion *unionArray = getUnionArrayPointer();
......@@ -675,39 +675,38 @@ TIntermTyped *TIntermConstantUnion::fold(
size_t objectSize = getType().getObjectSize();
if (constantNode)
if (rightNode)
{
// binary operations
TIntermConstantUnion *node = constantNode->getAsConstantUnion();
ConstantUnion *rightUnionArray = node->getUnionArrayPointer();
ConstantUnion *rightUnionArray = rightNode->getUnionArrayPointer();
TType returnType = getType();
if (!rightUnionArray)
return nullptr;
// for a case like float f = 1.2 + vec4(2,3,4,5);
if (constantNode->getType().getObjectSize() == 1 && objectSize > 1)
// for a case like float f = vec4(2, 3, 4, 5) + 1.2;
if (rightNode->getType().getObjectSize() == 1 && objectSize > 1)
{
rightUnionArray = new ConstantUnion[objectSize];
for (size_t i = 0; i < objectSize; ++i)
{
rightUnionArray[i] = *node->getUnionArrayPointer();
rightUnionArray[i] = *rightNode->getUnionArrayPointer();
}
returnType = getType();
}
else if (constantNode->getType().getObjectSize() > 1 && objectSize == 1)
else if (rightNode->getType().getObjectSize() > 1 && objectSize == 1)
{
// for a case like float f = vec4(2,3,4,5) + 1.2;
unionArray = new ConstantUnion[constantNode->getType().getObjectSize()];
for (size_t i = 0; i < constantNode->getType().getObjectSize(); ++i)
// for a case like float f = 1.2 + vec4(2, 3, 4, 5);
unionArray = new ConstantUnion[rightNode->getType().getObjectSize()];
for (size_t i = 0; i < rightNode->getType().getObjectSize(); ++i)
{
unionArray[i] = *getUnionArrayPointer();
}
returnType = node->getType();
objectSize = constantNode->getType().getObjectSize();
returnType = rightNode->getType();
objectSize = rightNode->getType().getObjectSize();
}
ConstantUnion *tempConstArray = NULL;
ConstantUnion *tempConstArray = nullptr;
TIntermConstantUnion *tempNode;
bool boolNodeFlag = false;
......@@ -735,7 +734,7 @@ TIntermTyped *TIntermConstantUnion::fold(
case EOpMatrixTimesMatrix:
{
if (getType().getBasicType() != EbtFloat ||
node->getBasicType() != EbtFloat)
rightNode->getBasicType() != EbtFloat)
{
infoSink.info.message(
EPrefixInternalError, getLine(),
......@@ -745,12 +744,12 @@ TIntermTyped *TIntermConstantUnion::fold(
const int leftCols = getCols();
const int leftRows = getRows();
const int rightCols = constantNode->getType().getCols();
const int rightRows = constantNode->getType().getRows();
const int rightCols = rightNode->getType().getCols();
const int rightRows = rightNode->getType().getRows();
const int resultCols = rightCols;
const int resultRows = leftRows;
tempConstArray = new ConstantUnion[resultCols*resultRows];
tempConstArray = new ConstantUnion[resultCols * resultRows];
for (int row = 0; row < resultRows; row++)
{
for (int column = 0; column < resultCols; column++)
......@@ -862,7 +861,7 @@ TIntermTyped *TIntermConstantUnion::fold(
case EOpMatrixTimesVector:
{
if (node->getBasicType() != EbtFloat)
if (rightNode->getBasicType() != EbtFloat)
{
infoSink.info.message(
EPrefixInternalError, getLine(),
......@@ -887,7 +886,7 @@ TIntermTyped *TIntermConstantUnion::fold(
}
}
returnType = node->getType();
returnType = rightNode->getType();
returnType.setPrimarySize(static_cast<unsigned char>(matrixRows));
tempNode = new TIntermConstantUnion(tempConstArray, returnType);
......@@ -906,8 +905,8 @@ TIntermTyped *TIntermConstantUnion::fold(
return nullptr;
}
const int matrixCols = constantNode->getType().getCols();
const int matrixRows = constantNode->getType().getRows();
const int matrixCols = rightNode->getType().getCols();
const int matrixRows = rightNode->getType().getRows();
tempConstArray = new ConstantUnion[matrixCols];
......@@ -1035,8 +1034,8 @@ TIntermTyped *TIntermConstantUnion::fold(
case EOpEqual:
if (getType().getBasicType() == EbtStruct)
{
if (!CompareStructure(node->getType(),
node->getUnionArrayPointer(),
if (!CompareStructure(rightNode->getType(),
rightNode->getUnionArrayPointer(),
unionArray))
{
boolNodeFlag = true;
......@@ -1073,8 +1072,8 @@ TIntermTyped *TIntermConstantUnion::fold(
case EOpNotEqual:
if (getType().getBasicType() == EbtStruct)
{
if (CompareStructure(node->getType(),
node->getUnionArrayPointer(),
if (CompareStructure(rightNode->getType(),
rightNode->getUnionArrayPointer(),
unionArray))
{
boolNodeFlag = true;
......
......@@ -293,7 +293,7 @@ class TIntermConstantUnion : public TIntermTyped
virtual void traverse(TIntermTraverser *);
virtual bool replaceChildNode(TIntermNode *, TIntermNode *) { return false; }
TIntermTyped *fold(TOperator, TIntermTyped *, TInfoSink &);
TIntermTyped *fold(TOperator op, TIntermConstantUnion *rightNode, TInfoSink &infoSink);
protected:
ConstantUnion *mUnionArrayPointer;
......
......@@ -143,7 +143,7 @@ TIntermTyped *TIntermediate::addUnaryMath(
if (childTempConstant)
{
TIntermTyped *newChild = childTempConstant->fold(op, 0, mInfoSink);
TIntermTyped *newChild = childTempConstant->fold(op, nullptr, mInfoSink);
if (newChild)
return newChild;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment