Commit b43846ee by Olli Etuaho

Unify aggregate operator folding with other constant folding

Setting the type for folded aggregate nodes should work in a similar way as other constant folding. Common functionality between the different folding functions is refactored into a single function. TEST=dEQP-GLES3.functional.shaders.constant_expressions.* BUG=angleproject:817 Change-Id: Ie0be561f4a30e52e52d570ff0b2bdb426f6e4f7a Reviewed-on: https://chromium-review.googlesource.com/275186Reviewed-by: 's avatarZhenyao Mo <zmo@chromium.org> Tested-by: 's avatarOlli Etuaho <oetuaho@nvidia.com> Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org>
parent 35b08e3c
...@@ -187,6 +187,18 @@ float VectorDotProduct(TConstantUnion *paramArray1, TConstantUnion *paramArray2, ...@@ -187,6 +187,18 @@ float VectorDotProduct(TConstantUnion *paramArray1, TConstantUnion *paramArray2,
return result; return result;
} }
TIntermTyped *CreateFoldedNode(TConstantUnion *constArray, const TIntermTyped *originalNode)
{
if (constArray == nullptr)
{
return nullptr;
}
TIntermTyped *folded = new TIntermConstantUnion(constArray, originalNode->getType());
folded->getTypePointer()->setQualifier(EvqConst);
folded->setLine(originalNode->getLine());
return folded;
}
} // namespace anonymous } // namespace anonymous
...@@ -756,14 +768,7 @@ TIntermTyped *TIntermBinary::fold(TInfoSink &infoSink) ...@@ -756,14 +768,7 @@ TIntermTyped *TIntermBinary::fold(TInfoSink &infoSink)
return nullptr; return nullptr;
} }
TConstantUnion *constArray = leftConstant->foldBinary(mOp, rightConstant, infoSink); TConstantUnion *constArray = leftConstant->foldBinary(mOp, rightConstant, infoSink);
if (constArray == nullptr) return CreateFoldedNode(constArray, this);
{
return nullptr;
}
TIntermTyped *folded = new TIntermConstantUnion(constArray, getType());
folded->getTypePointer()->setQualifier(EvqConst);
folded->setLine(getLine());
return folded;
} }
TIntermTyped *TIntermUnary::fold(TInfoSink &infoSink) TIntermTyped *TIntermUnary::fold(TInfoSink &infoSink)
...@@ -774,14 +779,21 @@ TIntermTyped *TIntermUnary::fold(TInfoSink &infoSink) ...@@ -774,14 +779,21 @@ TIntermTyped *TIntermUnary::fold(TInfoSink &infoSink)
return nullptr; return nullptr;
} }
TConstantUnion *constArray = operandConstant->foldUnary(mOp, infoSink); TConstantUnion *constArray = operandConstant->foldUnary(mOp, infoSink);
if (constArray == nullptr) return CreateFoldedNode(constArray, this);
}
TIntermTyped *TIntermAggregate::fold(TInfoSink &infoSink)
{
// Make sure that all params are constant before actual constant folding.
for (auto *param : *getSequence())
{ {
return nullptr; if (param->getAsConstantUnion() == nullptr)
{
return nullptr;
}
} }
TIntermTyped *folded = new TIntermConstantUnion(constArray, getType()); TConstantUnion *constArray = TIntermConstantUnion::FoldAggregateBuiltIn(this, infoSink);
folded->getTypePointer()->setQualifier(EvqConst); return CreateFoldedNode(constArray, this);
folded->setLine(getLine());
return folded;
} }
// //
...@@ -1591,21 +1603,20 @@ bool TIntermConstantUnion::foldFloatTypeUnary(const TConstantUnion &parameter, F ...@@ -1591,21 +1603,20 @@ bool TIntermConstantUnion::foldFloatTypeUnary(const TConstantUnion &parameter, F
} }
// static // static
TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAggregate *aggregate, TInfoSink &infoSink) TConstantUnion *TIntermConstantUnion::FoldAggregateBuiltIn(TIntermAggregate *aggregate, TInfoSink &infoSink)
{ {
TOperator op = aggregate->getOp();
TIntermSequence *sequence = aggregate->getSequence(); TIntermSequence *sequence = aggregate->getSequence();
unsigned int paramsCount = sequence->size(); unsigned int paramsCount = sequence->size();
std::vector<TConstantUnion *> unionArrays(paramsCount); std::vector<TConstantUnion *> unionArrays(paramsCount);
std::vector<size_t> objectSizes(paramsCount); std::vector<size_t> objectSizes(paramsCount);
TType *maxSizeType = nullptr; size_t maxObjectSize = 0;
TBasicType basicType = EbtVoid; TBasicType basicType = EbtVoid;
TSourceLoc loc; TSourceLoc loc;
for (unsigned int i = 0; i < paramsCount; i++) for (unsigned int i = 0; i < paramsCount; i++)
{ {
TIntermConstantUnion *paramConstant = (*sequence)[i]->getAsConstantUnion(); TIntermConstantUnion *paramConstant = (*sequence)[i]->getAsConstantUnion();
// Make sure that all params are constant before actual constant folding. ASSERT(paramConstant != nullptr); // Should be checked already.
if (!paramConstant)
return nullptr;
if (i == 0) if (i == 0)
{ {
...@@ -1614,18 +1625,15 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -1614,18 +1625,15 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
} }
unionArrays[i] = paramConstant->getUnionArrayPointer(); unionArrays[i] = paramConstant->getUnionArrayPointer();
objectSizes[i] = paramConstant->getType().getObjectSize(); objectSizes[i] = paramConstant->getType().getObjectSize();
if (maxSizeType == nullptr || (objectSizes[i] >= maxSizeType->getObjectSize())) if (objectSizes[i] > maxObjectSize)
maxSizeType = paramConstant->getTypePointer(); maxObjectSize = objectSizes[i];
} }
size_t maxObjectSize = maxSizeType->getObjectSize();
for (unsigned int i = 0; i < paramsCount; i++) for (unsigned int i = 0; i < paramsCount; i++)
if (objectSizes[i] != maxObjectSize) if (objectSizes[i] != maxObjectSize)
unionArrays[i] = Vectorize(*unionArrays[i], maxObjectSize); unionArrays[i] = Vectorize(*unionArrays[i], maxObjectSize);
TConstantUnion *tempConstArray = nullptr; TConstantUnion *resultArray = nullptr;
TIntermConstantUnion *tempNode = nullptr;
TType returnType = *maxSizeType;
if (paramsCount == 2) if (paramsCount == 2)
{ {
// //
...@@ -1637,16 +1645,16 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -1637,16 +1645,16 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
{ {
if (basicType == EbtFloat) if (basicType == EbtFloat)
{ {
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
{ {
float y = unionArrays[0][i].getFConst(); float y = unionArrays[0][i].getFConst();
float x = unionArrays[1][i].getFConst(); float x = unionArrays[1][i].getFConst();
// Results are undefined if x and y are both 0. // Results are undefined if x and y are both 0.
if (x == 0.0f && y == 0.0f) if (x == 0.0f && y == 0.0f)
UndefinedConstantFoldingError(loc, op, basicType, infoSink, &tempConstArray[i]); UndefinedConstantFoldingError(loc, op, basicType, infoSink, &resultArray[i]);
else else
tempConstArray[i].setFConst(atan2f(y, x)); resultArray[i].setFConst(atan2f(y, x));
} }
} }
else else
...@@ -1658,7 +1666,7 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -1658,7 +1666,7 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
{ {
if (basicType == EbtFloat) if (basicType == EbtFloat)
{ {
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
{ {
float x = unionArrays[0][i].getFConst(); float x = unionArrays[0][i].getFConst();
...@@ -1666,11 +1674,11 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -1666,11 +1674,11 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
// Results are undefined if x < 0. // Results are undefined if x < 0.
// Results are undefined if x = 0 and y <= 0. // Results are undefined if x = 0 and y <= 0.
if (x < 0.0f) if (x < 0.0f)
UndefinedConstantFoldingError(loc, op, basicType, infoSink, &tempConstArray[i]); UndefinedConstantFoldingError(loc, op, basicType, infoSink, &resultArray[i]);
else if (x == 0.0f && y <= 0.0f) else if (x == 0.0f && y <= 0.0f)
UndefinedConstantFoldingError(loc, op, basicType, infoSink, &tempConstArray[i]); UndefinedConstantFoldingError(loc, op, basicType, infoSink, &resultArray[i]);
else else
tempConstArray[i].setFConst(powf(x, y)); resultArray[i].setFConst(powf(x, y));
} }
} }
else else
...@@ -1682,12 +1690,12 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -1682,12 +1690,12 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
{ {
if (basicType == EbtFloat) if (basicType == EbtFloat)
{ {
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
{ {
float x = unionArrays[0][i].getFConst(); float x = unionArrays[0][i].getFConst();
float y = unionArrays[1][i].getFConst(); float y = unionArrays[1][i].getFConst();
tempConstArray[i].setFConst(x - y * floorf(x / y)); resultArray[i].setFConst(x - y * floorf(x / y));
} }
} }
else else
...@@ -1697,19 +1705,19 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -1697,19 +1705,19 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
case EOpMin: case EOpMin:
{ {
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
{ {
switch (basicType) switch (basicType)
{ {
case EbtFloat: case EbtFloat:
tempConstArray[i].setFConst(std::min(unionArrays[0][i].getFConst(), unionArrays[1][i].getFConst())); resultArray[i].setFConst(std::min(unionArrays[0][i].getFConst(), unionArrays[1][i].getFConst()));
break; break;
case EbtInt: case EbtInt:
tempConstArray[i].setIConst(std::min(unionArrays[0][i].getIConst(), unionArrays[1][i].getIConst())); resultArray[i].setIConst(std::min(unionArrays[0][i].getIConst(), unionArrays[1][i].getIConst()));
break; break;
case EbtUInt: case EbtUInt:
tempConstArray[i].setUConst(std::min(unionArrays[0][i].getUConst(), unionArrays[1][i].getUConst())); resultArray[i].setUConst(std::min(unionArrays[0][i].getUConst(), unionArrays[1][i].getUConst()));
break; break;
default: default:
UNREACHABLE(); UNREACHABLE();
...@@ -1721,19 +1729,19 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -1721,19 +1729,19 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
case EOpMax: case EOpMax:
{ {
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
{ {
switch (basicType) switch (basicType)
{ {
case EbtFloat: case EbtFloat:
tempConstArray[i].setFConst(std::max(unionArrays[0][i].getFConst(), unionArrays[1][i].getFConst())); resultArray[i].setFConst(std::max(unionArrays[0][i].getFConst(), unionArrays[1][i].getFConst()));
break; break;
case EbtInt: case EbtInt:
tempConstArray[i].setIConst(std::max(unionArrays[0][i].getIConst(), unionArrays[1][i].getIConst())); resultArray[i].setIConst(std::max(unionArrays[0][i].getIConst(), unionArrays[1][i].getIConst()));
break; break;
case EbtUInt: case EbtUInt:
tempConstArray[i].setUConst(std::max(unionArrays[0][i].getUConst(), unionArrays[1][i].getUConst())); resultArray[i].setUConst(std::max(unionArrays[0][i].getUConst(), unionArrays[1][i].getUConst()));
break; break;
default: default:
UNREACHABLE(); UNREACHABLE();
...@@ -1747,9 +1755,9 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -1747,9 +1755,9 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
{ {
if (basicType == EbtFloat) if (basicType == EbtFloat)
{ {
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
tempConstArray[i].setFConst(unionArrays[1][i].getFConst() < unionArrays[0][i].getFConst() ? 0.0f : 1.0f); resultArray[i].setFConst(unionArrays[1][i].getFConst() < unionArrays[0][i].getFConst() ? 0.0f : 1.0f);
} }
else else
UNREACHABLE(); UNREACHABLE();
...@@ -1758,19 +1766,19 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -1758,19 +1766,19 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
case EOpLessThan: case EOpLessThan:
{ {
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
{ {
switch (basicType) switch (basicType)
{ {
case EbtFloat: case EbtFloat:
tempConstArray[i].setBConst(unionArrays[0][i].getFConst() < unionArrays[1][i].getFConst()); resultArray[i].setBConst(unionArrays[0][i].getFConst() < unionArrays[1][i].getFConst());
break; break;
case EbtInt: case EbtInt:
tempConstArray[i].setBConst(unionArrays[0][i].getIConst() < unionArrays[1][i].getIConst()); resultArray[i].setBConst(unionArrays[0][i].getIConst() < unionArrays[1][i].getIConst());
break; break;
case EbtUInt: case EbtUInt:
tempConstArray[i].setBConst(unionArrays[0][i].getUConst() < unionArrays[1][i].getUConst()); resultArray[i].setBConst(unionArrays[0][i].getUConst() < unionArrays[1][i].getUConst());
break; break;
default: default:
UNREACHABLE(); UNREACHABLE();
...@@ -1782,19 +1790,19 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -1782,19 +1790,19 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
case EOpLessThanEqual: case EOpLessThanEqual:
{ {
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
{ {
switch (basicType) switch (basicType)
{ {
case EbtFloat: case EbtFloat:
tempConstArray[i].setBConst(unionArrays[0][i].getFConst() <= unionArrays[1][i].getFConst()); resultArray[i].setBConst(unionArrays[0][i].getFConst() <= unionArrays[1][i].getFConst());
break; break;
case EbtInt: case EbtInt:
tempConstArray[i].setBConst(unionArrays[0][i].getIConst() <= unionArrays[1][i].getIConst()); resultArray[i].setBConst(unionArrays[0][i].getIConst() <= unionArrays[1][i].getIConst());
break; break;
case EbtUInt: case EbtUInt:
tempConstArray[i].setBConst(unionArrays[0][i].getUConst() <= unionArrays[1][i].getUConst()); resultArray[i].setBConst(unionArrays[0][i].getUConst() <= unionArrays[1][i].getUConst());
break; break;
default: default:
UNREACHABLE(); UNREACHABLE();
...@@ -1806,43 +1814,43 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -1806,43 +1814,43 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
case EOpGreaterThan: case EOpGreaterThan:
{ {
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
{ {
switch (basicType) switch (basicType)
{ {
case EbtFloat: case EbtFloat:
tempConstArray[i].setBConst(unionArrays[0][i].getFConst() > unionArrays[1][i].getFConst()); resultArray[i].setBConst(unionArrays[0][i].getFConst() > unionArrays[1][i].getFConst());
break; break;
case EbtInt: case EbtInt:
tempConstArray[i].setBConst(unionArrays[0][i].getIConst() > unionArrays[1][i].getIConst()); resultArray[i].setBConst(unionArrays[0][i].getIConst() > unionArrays[1][i].getIConst());
break; break;
case EbtUInt: case EbtUInt:
tempConstArray[i].setBConst(unionArrays[0][i].getUConst() > unionArrays[1][i].getUConst()); resultArray[i].setBConst(unionArrays[0][i].getUConst() > unionArrays[1][i].getUConst());
break; break;
default: default:
UNREACHABLE(); UNREACHABLE();
break; break;
} }
} }
} }
break; break;
case EOpGreaterThanEqual: case EOpGreaterThanEqual:
{ {
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
{ {
switch (basicType) switch (basicType)
{ {
case EbtFloat: case EbtFloat:
tempConstArray[i].setBConst(unionArrays[0][i].getFConst() >= unionArrays[1][i].getFConst()); resultArray[i].setBConst(unionArrays[0][i].getFConst() >= unionArrays[1][i].getFConst());
break; break;
case EbtInt: case EbtInt:
tempConstArray[i].setBConst(unionArrays[0][i].getIConst() >= unionArrays[1][i].getIConst()); resultArray[i].setBConst(unionArrays[0][i].getIConst() >= unionArrays[1][i].getIConst());
break; break;
case EbtUInt: case EbtUInt:
tempConstArray[i].setBConst(unionArrays[0][i].getUConst() >= unionArrays[1][i].getUConst()); resultArray[i].setBConst(unionArrays[0][i].getUConst() >= unionArrays[1][i].getUConst());
break; break;
default: default:
UNREACHABLE(); UNREACHABLE();
...@@ -1854,22 +1862,22 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -1854,22 +1862,22 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
case EOpVectorEqual: case EOpVectorEqual:
{ {
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
{ {
switch (basicType) switch (basicType)
{ {
case EbtFloat: case EbtFloat:
tempConstArray[i].setBConst(unionArrays[0][i].getFConst() == unionArrays[1][i].getFConst()); resultArray[i].setBConst(unionArrays[0][i].getFConst() == unionArrays[1][i].getFConst());
break; break;
case EbtInt: case EbtInt:
tempConstArray[i].setBConst(unionArrays[0][i].getIConst() == unionArrays[1][i].getIConst()); resultArray[i].setBConst(unionArrays[0][i].getIConst() == unionArrays[1][i].getIConst());
break; break;
case EbtUInt: case EbtUInt:
tempConstArray[i].setBConst(unionArrays[0][i].getUConst() == unionArrays[1][i].getUConst()); resultArray[i].setBConst(unionArrays[0][i].getUConst() == unionArrays[1][i].getUConst());
break; break;
case EbtBool: case EbtBool:
tempConstArray[i].setBConst(unionArrays[0][i].getBConst() == unionArrays[1][i].getBConst()); resultArray[i].setBConst(unionArrays[0][i].getBConst() == unionArrays[1][i].getBConst());
break; break;
default: default:
UNREACHABLE(); UNREACHABLE();
...@@ -1881,22 +1889,22 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -1881,22 +1889,22 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
case EOpVectorNotEqual: case EOpVectorNotEqual:
{ {
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
{ {
switch (basicType) switch (basicType)
{ {
case EbtFloat: case EbtFloat:
tempConstArray[i].setBConst(unionArrays[0][i].getFConst() != unionArrays[1][i].getFConst()); resultArray[i].setBConst(unionArrays[0][i].getFConst() != unionArrays[1][i].getFConst());
break; break;
case EbtInt: case EbtInt:
tempConstArray[i].setBConst(unionArrays[0][i].getIConst() != unionArrays[1][i].getIConst()); resultArray[i].setBConst(unionArrays[0][i].getIConst() != unionArrays[1][i].getIConst());
break; break;
case EbtUInt: case EbtUInt:
tempConstArray[i].setBConst(unionArrays[0][i].getUConst() != unionArrays[1][i].getUConst()); resultArray[i].setBConst(unionArrays[0][i].getUConst() != unionArrays[1][i].getUConst());
break; break;
case EbtBool: case EbtBool:
tempConstArray[i].setBConst(unionArrays[0][i].getBConst() != unionArrays[1][i].getBConst()); resultArray[i].setBConst(unionArrays[0][i].getBConst() != unionArrays[1][i].getBConst());
break; break;
default: default:
UNREACHABLE(); UNREACHABLE();
...@@ -1910,24 +1918,25 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -1910,24 +1918,25 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
if (basicType == EbtFloat) if (basicType == EbtFloat)
{ {
TConstantUnion *distanceArray = new TConstantUnion[maxObjectSize]; TConstantUnion *distanceArray = new TConstantUnion[maxObjectSize];
tempConstArray = new TConstantUnion(); resultArray = new TConstantUnion();
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
{ {
float x = unionArrays[0][i].getFConst(); float x = unionArrays[0][i].getFConst();
float y = unionArrays[1][i].getFConst(); float y = unionArrays[1][i].getFConst();
distanceArray[i].setFConst(x - y); distanceArray[i].setFConst(x - y);
} }
tempConstArray->setFConst(VectorLength(distanceArray, maxObjectSize)); resultArray->setFConst(VectorLength(distanceArray, maxObjectSize));
} }
else else
UNREACHABLE(); UNREACHABLE();
break; break;
case EOpDot: case EOpDot:
if (basicType == EbtFloat) if (basicType == EbtFloat)
{ {
tempConstArray = new TConstantUnion(); resultArray = new TConstantUnion();
tempConstArray->setFConst(VectorDotProduct(unionArrays[0], unionArrays[1], maxObjectSize)); resultArray->setFConst(VectorDotProduct(unionArrays[0], unionArrays[1], maxObjectSize));
} }
else else
UNREACHABLE(); UNREACHABLE();
...@@ -1936,16 +1945,16 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -1936,16 +1945,16 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
case EOpCross: case EOpCross:
if (basicType == EbtFloat && maxObjectSize == 3) if (basicType == EbtFloat && maxObjectSize == 3)
{ {
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
float x0 = unionArrays[0][0].getFConst(); float x0 = unionArrays[0][0].getFConst();
float x1 = unionArrays[0][1].getFConst(); float x1 = unionArrays[0][1].getFConst();
float x2 = unionArrays[0][2].getFConst(); float x2 = unionArrays[0][2].getFConst();
float y0 = unionArrays[1][0].getFConst(); float y0 = unionArrays[1][0].getFConst();
float y1 = unionArrays[1][1].getFConst(); float y1 = unionArrays[1][1].getFConst();
float y2 = unionArrays[1][2].getFConst(); float y2 = unionArrays[1][2].getFConst();
tempConstArray[0].setFConst(x1 * y2 - y1 * x2); resultArray[0].setFConst(x1 * y2 - y1 * x2);
tempConstArray[1].setFConst(x2 * y0 - y2 * x0); resultArray[1].setFConst(x2 * y0 - y2 * x0);
tempConstArray[2].setFConst(x0 * y1 - y0 * x1); resultArray[2].setFConst(x0 * y1 - y0 * x1);
} }
else else
UNREACHABLE(); UNREACHABLE();
...@@ -1957,13 +1966,13 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -1957,13 +1966,13 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
// genType reflect (genType I, genType N) : // genType reflect (genType I, genType N) :
// For the incident vector I and surface orientation N, returns the reflection direction: // For the incident vector I and surface orientation N, returns the reflection direction:
// I - 2 * dot(N, I) * N. // I - 2 * dot(N, I) * N.
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
float dotProduct = VectorDotProduct(unionArrays[1], unionArrays[0], maxObjectSize); float dotProduct = VectorDotProduct(unionArrays[1], unionArrays[0], maxObjectSize);
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
{ {
float result = unionArrays[0][i].getFConst() - float result = unionArrays[0][i].getFConst() -
2.0f * dotProduct * unionArrays[1][i].getFConst(); 2.0f * dotProduct * unionArrays[1][i].getFConst();
tempConstArray[i].setFConst(result); resultArray[i].setFConst(result);
} }
} }
else else
...@@ -1985,7 +1994,7 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -1985,7 +1994,7 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
{ {
case EOpClamp: case EOpClamp:
{ {
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
{ {
switch (basicType) switch (basicType)
...@@ -1997,9 +2006,9 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -1997,9 +2006,9 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
float max = unionArrays[2][i].getFConst(); float max = unionArrays[2][i].getFConst();
// Results are undefined if min > max. // Results are undefined if min > max.
if (min > max) if (min > max)
UndefinedConstantFoldingError(loc, op, basicType, infoSink, &tempConstArray[i]); UndefinedConstantFoldingError(loc, op, basicType, infoSink, &resultArray[i]);
else else
tempConstArray[i].setFConst(gl::clamp(x, min, max)); resultArray[i].setFConst(gl::clamp(x, min, max));
} }
break; break;
case EbtInt: case EbtInt:
...@@ -2009,9 +2018,9 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -2009,9 +2018,9 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
int max = unionArrays[2][i].getIConst(); int max = unionArrays[2][i].getIConst();
// Results are undefined if min > max. // Results are undefined if min > max.
if (min > max) if (min > max)
UndefinedConstantFoldingError(loc, op, basicType, infoSink, &tempConstArray[i]); UndefinedConstantFoldingError(loc, op, basicType, infoSink, &resultArray[i]);
else else
tempConstArray[i].setIConst(gl::clamp(x, min, max)); resultArray[i].setIConst(gl::clamp(x, min, max));
} }
break; break;
case EbtUInt: case EbtUInt:
...@@ -2021,9 +2030,9 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -2021,9 +2030,9 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
unsigned int max = unionArrays[2][i].getUConst(); unsigned int max = unionArrays[2][i].getUConst();
// Results are undefined if min > max. // Results are undefined if min > max.
if (min > max) if (min > max)
UndefinedConstantFoldingError(loc, op, basicType, infoSink, &tempConstArray[i]); UndefinedConstantFoldingError(loc, op, basicType, infoSink, &resultArray[i]);
else else
tempConstArray[i].setUConst(gl::clamp(x, min, max)); resultArray[i].setUConst(gl::clamp(x, min, max));
} }
break; break;
default: default:
...@@ -2038,7 +2047,7 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -2038,7 +2047,7 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
{ {
if (basicType == EbtFloat) if (basicType == EbtFloat)
{ {
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
{ {
float x = unionArrays[0][i].getFConst(); float x = unionArrays[0][i].getFConst();
...@@ -2048,7 +2057,7 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -2048,7 +2057,7 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
{ {
// Returns the linear blend of x and y, i.e., x * (1 - a) + y * a. // Returns the linear blend of x and y, i.e., x * (1 - a) + y * a.
float a = unionArrays[2][i].getFConst(); float a = unionArrays[2][i].getFConst();
tempConstArray[i].setFConst(x * (1.0f - a) + y * a); resultArray[i].setFConst(x * (1.0f - a) + y * a);
} }
else // 3rd parameter is EbtBool else // 3rd parameter is EbtBool
{ {
...@@ -2057,7 +2066,7 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -2057,7 +2066,7 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
// For a component of a that is false, the corresponding component of x is returned. // For a component of a that is false, the corresponding component of x is returned.
// For a component of a that is true, the corresponding component of y is returned. // For a component of a that is true, the corresponding component of y is returned.
bool a = unionArrays[2][i].getBConst(); bool a = unionArrays[2][i].getBConst();
tempConstArray[i].setFConst(a ? y : x); resultArray[i].setFConst(a ? y : x);
} }
} }
} }
...@@ -2070,7 +2079,7 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -2070,7 +2079,7 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
{ {
if (basicType == EbtFloat) if (basicType == EbtFloat)
{ {
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
{ {
float edge0 = unionArrays[0][i].getFConst(); float edge0 = unionArrays[0][i].getFConst();
...@@ -2079,14 +2088,14 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -2079,14 +2088,14 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
// Results are undefined if edge0 >= edge1. // Results are undefined if edge0 >= edge1.
if (edge0 >= edge1) if (edge0 >= edge1)
{ {
UndefinedConstantFoldingError(loc, op, basicType, infoSink, &tempConstArray[i]); UndefinedConstantFoldingError(loc, op, basicType, infoSink, &resultArray[i]);
} }
else else
{ {
// Returns 0.0 if x <= edge0 and 1.0 if x >= edge1 and performs smooth // Returns 0.0 if x <= edge0 and 1.0 if x >= edge1 and performs smooth
// Hermite interpolation between 0 and 1 when edge0 < x < edge1. // Hermite interpolation between 0 and 1 when edge0 < x < edge1.
float t = gl::clamp((x - edge0) / (edge1 - edge0), 0.0f, 1.0f); float t = gl::clamp((x - edge0) / (edge1 - edge0), 0.0f, 1.0f);
tempConstArray[i].setFConst(t * t * (3.0f - 2.0f * t)); resultArray[i].setFConst(t * t * (3.0f - 2.0f * t));
} }
} }
} }
...@@ -2100,14 +2109,14 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -2100,14 +2109,14 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
{ {
// genType faceforward(genType N, genType I, genType Nref) : // genType faceforward(genType N, genType I, genType Nref) :
// If dot(Nref, I) < 0 return N, otherwise return -N. // If dot(Nref, I) < 0 return N, otherwise return -N.
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
float dotProduct = VectorDotProduct(unionArrays[2], unionArrays[1], maxObjectSize); float dotProduct = VectorDotProduct(unionArrays[2], unionArrays[1], maxObjectSize);
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
{ {
if (dotProduct < 0) if (dotProduct < 0)
tempConstArray[i].setFConst(unionArrays[0][i].getFConst()); resultArray[i].setFConst(unionArrays[0][i].getFConst());
else else
tempConstArray[i].setFConst(-unionArrays[0][i].getFConst()); resultArray[i].setFConst(-unionArrays[0][i].getFConst());
} }
} }
else else
...@@ -2125,16 +2134,16 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -2125,16 +2134,16 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
// return genType(0.0) // return genType(0.0)
// else // else
// return eta * I - (eta * dot(N, I) + sqrt(k)) * N // return eta * I - (eta * dot(N, I) + sqrt(k)) * N
tempConstArray = new TConstantUnion[maxObjectSize]; resultArray = new TConstantUnion[maxObjectSize];
float dotProduct = VectorDotProduct(unionArrays[1], unionArrays[0], maxObjectSize); float dotProduct = VectorDotProduct(unionArrays[1], unionArrays[0], maxObjectSize);
for (size_t i = 0; i < maxObjectSize; i++) for (size_t i = 0; i < maxObjectSize; i++)
{ {
float eta = unionArrays[2][i].getFConst(); float eta = unionArrays[2][i].getFConst();
float k = 1.0f - eta * eta * (1.0f - dotProduct * dotProduct); float k = 1.0f - eta * eta * (1.0f - dotProduct * dotProduct);
if (k < 0.0f) if (k < 0.0f)
tempConstArray[i].setFConst(0.0f); resultArray[i].setFConst(0.0f);
else else
tempConstArray[i].setFConst(eta * unionArrays[0][i].getFConst() - resultArray[i].setFConst(eta * unionArrays[0][i].getFConst() -
(eta * dotProduct + sqrtf(k)) * unionArrays[1][i].getFConst()); (eta * dotProduct + sqrtf(k)) * unionArrays[1][i].getFConst());
} }
} }
...@@ -2148,13 +2157,7 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg ...@@ -2148,13 +2157,7 @@ TIntermTyped *TIntermConstantUnion::FoldAggregateBuiltIn(TOperator op, TIntermAg
return nullptr; return nullptr;
} }
} }
return resultArray;
if (tempConstArray)
{
tempNode = new TIntermConstantUnion(tempConstArray, returnType);
tempNode->setLine(loc);
}
return tempNode;
} }
// static // static
......
...@@ -302,7 +302,7 @@ class TIntermConstantUnion : public TIntermTyped ...@@ -302,7 +302,7 @@ class TIntermConstantUnion : public TIntermTyped
TConstantUnion *foldBinary(TOperator op, TIntermConstantUnion *rightNode, TInfoSink &infoSink); TConstantUnion *foldBinary(TOperator op, TIntermConstantUnion *rightNode, TInfoSink &infoSink);
TConstantUnion *foldUnary(TOperator op, TInfoSink &infoSink); TConstantUnion *foldUnary(TOperator op, TInfoSink &infoSink);
static TIntermTyped *FoldAggregateBuiltIn(TOperator op, TIntermAggregate *aggregate, TInfoSink &infoSink); static TConstantUnion *FoldAggregateBuiltIn(TIntermAggregate *aggregate, TInfoSink &infoSink);
protected: protected:
TConstantUnion *mUnionArrayPointer; TConstantUnion *mUnionArrayPointer;
...@@ -443,6 +443,7 @@ class TIntermAggregate : public TIntermOperator ...@@ -443,6 +443,7 @@ class TIntermAggregate : public TIntermOperator
bool insertChildNodes(TIntermSequence::size_type position, TIntermSequence insertions); bool insertChildNodes(TIntermSequence::size_type position, TIntermSequence insertions);
// Conservatively assume function calls and other aggregate operators have side-effects // Conservatively assume function calls and other aggregate operators have side-effects
virtual bool hasSideEffects() const { return true; } virtual bool hasSideEffects() const { return true; }
TIntermTyped *fold(TInfoSink &infoSink);
TIntermSequence *getSequence() { return &mSequence; } TIntermSequence *getSequence() { return &mSequence; }
......
...@@ -441,9 +441,9 @@ bool TIntermediate::postProcess(TIntermNode *root) ...@@ -441,9 +441,9 @@ bool TIntermediate::postProcess(TIntermNode *root)
return true; return true;
} }
TIntermTyped *TIntermediate::foldAggregateBuiltIn(TOperator op, TIntermAggregate *aggregate) TIntermTyped *TIntermediate::foldAggregateBuiltIn(TIntermAggregate *aggregate)
{ {
switch (op) switch (aggregate->getOp())
{ {
case EOpAtan: case EOpAtan:
case EOpPow: case EOpPow:
...@@ -466,7 +466,7 @@ TIntermTyped *TIntermediate::foldAggregateBuiltIn(TOperator op, TIntermAggregate ...@@ -466,7 +466,7 @@ TIntermTyped *TIntermediate::foldAggregateBuiltIn(TOperator op, TIntermAggregate
case EOpFaceForward: case EOpFaceForward:
case EOpReflect: case EOpReflect:
case EOpRefract: case EOpRefract:
return TIntermConstantUnion::FoldAggregateBuiltIn(op, aggregate, mInfoSink); return aggregate->fold(mInfoSink);
default: default:
// Constant folding not supported for the built-in. // Constant folding not supported for the built-in.
return nullptr; return nullptr;
......
...@@ -64,7 +64,7 @@ class TIntermediate ...@@ -64,7 +64,7 @@ class TIntermediate
static void outputTree(TIntermNode *, TInfoSinkBase &); static void outputTree(TIntermNode *, TInfoSinkBase &);
TIntermTyped *foldAggregateBuiltIn(TOperator op, TIntermAggregate *aggregate); TIntermTyped *foldAggregateBuiltIn(TIntermAggregate *aggregate);
private: private:
void operator=(TIntermediate &); // prevent assignments void operator=(TIntermediate &); // prevent assignments
......
...@@ -3512,19 +3512,20 @@ TIntermTyped *TParseContext::addFunctionCallOrMethod(TFunction *fnCall, TIntermN ...@@ -3512,19 +3512,20 @@ TIntermTyped *TParseContext::addFunctionCallOrMethod(TFunction *fnCall, TIntermN
TIntermAggregate *aggregate = intermediate.setAggregateOperator(paramNode, op, loc); TIntermAggregate *aggregate = intermediate.setAggregateOperator(paramNode, op, loc);
aggregate->setType(fnCandidate->getReturnType()); aggregate->setType(fnCandidate->getReturnType());
aggregate->setPrecisionFromChildren(); aggregate->setPrecisionFromChildren();
callNode = aggregate;
// Some built-in functions have out parameters too. // Some built-in functions have out parameters too.
functionCallLValueErrorCheck(fnCandidate, aggregate); functionCallLValueErrorCheck(fnCandidate, aggregate);
// See if we can constant fold a built-in. // See if we can constant fold a built-in.
TIntermTyped *foldedNode = intermediate.foldAggregateBuiltIn(op, aggregate); TIntermTyped *foldedNode = intermediate.foldAggregateBuiltIn(aggregate);
if (foldedNode) if (foldedNode)
{ {
foldedNode->setType(callNode->getType());
foldedNode->getTypePointer()->setQualifier(EvqConst);
callNode = foldedNode; callNode = foldedNode;
} }
else
{
callNode = aggregate;
}
} }
} }
else else
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment