Commit 1dded803 by Olli Etuaho Committed by Commit Bot

Check multiplication validity in ParseContext

This improves separation of responsibilities in the code: ParseContext should handle operand type validation, while TIntermBinary::promote should ideally only determine the type of the node based on the operation and operands. BUG=angleproject:952 TEST=angle_unittests Change-Id: I9a8d8ede21cdf35de631623a62194c0da5c604d2 Reviewed-on: https://chromium-review.googlesource.com/372622 Commit-Queue: Olli Etuaho <oetuaho@nvidia.com> Reviewed-by: 's avatarCorentin Wallez <cwallez@chromium.org>
parent 9949f04d
...@@ -34,41 +34,6 @@ TPrecision GetHigherPrecision(TPrecision left, TPrecision right) ...@@ -34,41 +34,6 @@ TPrecision GetHigherPrecision(TPrecision left, TPrecision right)
return left > right ? left : right; return left > right ? left : right;
} }
bool ValidateMultiplication(TOperator op, const TType &left, const TType &right)
{
switch (op)
{
case EOpMul:
case EOpMulAssign:
return left.getNominalSize() == right.getNominalSize() &&
left.getSecondarySize() == right.getSecondarySize();
case EOpVectorTimesScalar:
case EOpVectorTimesScalarAssign:
return true;
case EOpVectorTimesMatrix:
return left.getNominalSize() == right.getRows();
case EOpVectorTimesMatrixAssign:
return left.getNominalSize() == right.getRows() &&
left.getNominalSize() == right.getCols();
case EOpMatrixTimesVector:
return left.getCols() == right.getNominalSize();
case EOpMatrixTimesScalar:
case EOpMatrixTimesScalarAssign:
return true;
case EOpMatrixTimesMatrix:
return left.getCols() == right.getRows();
case EOpMatrixTimesMatrixAssign:
// We need to check two things:
// 1. The matrix multiplication step is valid.
// 2. The result will have the same number of columns as the lvalue.
return left.getCols() == right.getRows() && left.getCols() == right.getCols();
default:
UNREACHABLE();
return false;
}
}
TConstantUnion *Vectorize(const TConstantUnion &constant, size_t size) TConstantUnion *Vectorize(const TConstantUnion &constant, size_t size)
{ {
TConstantUnion *constUnion = new TConstantUnion[size]; TConstantUnion *constUnion = new TConstantUnion[size];
...@@ -513,6 +478,94 @@ bool TIntermOperator::isConstructor() const ...@@ -513,6 +478,94 @@ bool TIntermOperator::isConstructor() const
} }
} }
TOperator TIntermBinary::GetMulOpBasedOnOperands(const TType &left, const TType &right)
{
if (left.isMatrix())
{
if (right.isMatrix())
{
return EOpMatrixTimesMatrix;
}
else
{
if (right.isVector())
{
return EOpMatrixTimesVector;
}
else
{
return EOpMatrixTimesScalar;
}
}
}
else
{
if (right.isMatrix())
{
if (left.isVector())
{
return EOpVectorTimesMatrix;
}
else
{
return EOpMatrixTimesScalar;
}
}
else
{
// Neither operand is a matrix.
if (left.isVector() == right.isVector())
{
// Leave as component product.
return EOpMul;
}
else
{
return EOpVectorTimesScalar;
}
}
}
}
TOperator TIntermBinary::GetMulAssignOpBasedOnOperands(const TType &left, const TType &right)
{
if (left.isMatrix())
{
if (right.isMatrix())
{
return EOpMatrixTimesMatrixAssign;
}
else
{
// right should be scalar, but this may not be validated yet.
return EOpMatrixTimesScalarAssign;
}
}
else
{
if (right.isMatrix())
{
// Left should be a vector, but this may not be validated yet.
return EOpVectorTimesMatrixAssign;
}
else
{
// Neither operand is a matrix.
if (left.isVector() == right.isVector())
{
// Leave as component product.
return EOpMulAssign;
}
else
{
// left should be vector and right should be scalar, but this may not be validated
// yet.
return EOpVectorTimesScalarAssign;
}
}
}
}
// //
// Make sure the type of a unary operator is appropriate for its // Make sure the type of a unary operator is appropriate for its
// combination of operation and operand type. // combination of operation and operand type.
...@@ -570,6 +623,9 @@ bool TIntermBinary::promote() ...@@ -570,6 +623,9 @@ bool TIntermBinary::promote()
{ {
ASSERT(mLeft->isArray() == mRight->isArray()); ASSERT(mLeft->isArray() == mRight->isArray());
ASSERT(!isMultiplication() ||
mOp == GetMulOpBasedOnOperands(mLeft->getType(), mRight->getType()));
// //
// Base assumption: just make the type the same as the left // Base assumption: just make the type the same as the left
// operand. Then only deviations from this need be coded. // operand. Then only deviations from this need be coded.
...@@ -633,127 +689,43 @@ bool TIntermBinary::promote() ...@@ -633,127 +689,43 @@ bool TIntermBinary::promote()
// Can these two operands be combined? // Can these two operands be combined?
// //
TBasicType basicType = mLeft->getBasicType(); TBasicType basicType = mLeft->getBasicType();
switch (mOp) switch (mOp)
{ {
case EOpMul: case EOpMul:
if (!mLeft->isMatrix() && mRight->isMatrix()) break;
{ case EOpMatrixTimesScalar:
if (mLeft->isVector()) if (mRight->isMatrix())
{
mOp = EOpVectorTimesMatrix;
setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(mRight->getCols()), 1));
}
else
{ {
mOp = EOpMatrixTimesScalar;
setType(TType(basicType, higherPrecision, resultQualifier, setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(mRight->getCols()), static_cast<unsigned char>(mRight->getCols()),
static_cast<unsigned char>(mRight->getRows()))); static_cast<unsigned char>(mRight->getRows())));
} }
} break;
else if (mLeft->isMatrix() && !mRight->isMatrix()) case EOpMatrixTimesVector:
{
if (mRight->isVector())
{
mOp = EOpMatrixTimesVector;
setType(TType(basicType, higherPrecision, resultQualifier, setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(mLeft->getRows()), 1)); static_cast<unsigned char>(mLeft->getRows()), 1));
} break;
else case EOpMatrixTimesMatrix:
{
mOp = EOpMatrixTimesScalar;
}
}
else if (mLeft->isMatrix() && mRight->isMatrix())
{
mOp = EOpMatrixTimesMatrix;
setType(TType(basicType, higherPrecision, resultQualifier, setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(mRight->getCols()), static_cast<unsigned char>(mRight->getCols()),
static_cast<unsigned char>(mLeft->getRows()))); static_cast<unsigned char>(mLeft->getRows())));
} break;
else if (!mLeft->isMatrix() && !mRight->isMatrix()) case EOpVectorTimesScalar:
{
if (mLeft->isVector() && mRight->isVector())
{
// leave as component product
}
else if (mLeft->isVector() || mRight->isVector())
{
mOp = EOpVectorTimesScalar;
setType(TType(basicType, higherPrecision, resultQualifier, setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(nominalSize), 1)); static_cast<unsigned char>(nominalSize), 1));
}
}
else
{
UNREACHABLE();
return false;
}
if (!ValidateMultiplication(mOp, mLeft->getType(), mRight->getType()))
{
return false;
}
break; break;
case EOpVectorTimesMatrix:
case EOpMulAssign:
if (!mLeft->isMatrix() && mRight->isMatrix())
{
if (mLeft->isVector())
{
mOp = EOpVectorTimesMatrixAssign;
}
else
{
return false;
}
}
else if (mLeft->isMatrix() && !mRight->isMatrix())
{
if (mRight->isVector())
{
return false;
}
else
{
mOp = EOpMatrixTimesScalarAssign;
}
}
else if (mLeft->isMatrix() && mRight->isMatrix())
{
mOp = EOpMatrixTimesMatrixAssign;
setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(mRight->getCols()),
static_cast<unsigned char>(mLeft->getRows())));
}
else if (!mLeft->isMatrix() && !mRight->isMatrix())
{
if (mLeft->isVector() && mRight->isVector())
{
// leave as component product
}
else if (mLeft->isVector() || mRight->isVector())
{
if (!mLeft->isVector())
return false;
mOp = EOpVectorTimesScalarAssign;
setType(TType(basicType, higherPrecision, resultQualifier, setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(mLeft->getNominalSize()), 1)); static_cast<unsigned char>(mRight->getCols()), 1));
} break;
} case EOpMulAssign:
else case EOpVectorTimesScalarAssign:
{ case EOpVectorTimesMatrixAssign:
UNREACHABLE(); case EOpMatrixTimesScalarAssign:
return false; case EOpMatrixTimesMatrixAssign:
} ASSERT(mOp == GetMulAssignOpBasedOnOperands(mLeft->getType(), mRight->getType()));
if (!ValidateMultiplication(mOp, mLeft->getType(), mRight->getType()))
{
return false;
}
break; break;
case EOpAssign: case EOpAssign:
case EOpInitialize: case EOpInitialize:
// No more additional checks are needed. // No more additional checks are needed.
...@@ -798,15 +770,13 @@ bool TIntermBinary::promote() ...@@ -798,15 +770,13 @@ bool TIntermBinary::promote()
// would be assigned to a scalar. A scalar can't be shifted by a // would be assigned to a scalar. A scalar can't be shifted by a
// vector either. // vector either.
if (!mRight->isScalar() && if (!mRight->isScalar() &&
(isAssignment() || (isAssignment() || mOp == EOpBitShiftLeft || mOp == EOpBitShiftRight))
mOp == EOpBitShiftLeft ||
mOp == EOpBitShiftRight))
return false; return false;
} }
{ {
const int secondarySize = std::max( const int secondarySize =
mLeft->getSecondarySize(), mRight->getSecondarySize()); std::max(mLeft->getSecondarySize(), mRight->getSecondarySize());
setType(TType(basicType, higherPrecision, resultQualifier, setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(nominalSize), static_cast<unsigned char>(nominalSize),
static_cast<unsigned char>(secondarySize))); static_cast<unsigned char>(secondarySize)));
......
...@@ -417,6 +417,9 @@ class TIntermBinary : public TIntermOperator ...@@ -417,6 +417,9 @@ class TIntermBinary : public TIntermOperator
TIntermTyped *deepCopy() const override { return new TIntermBinary(*this); } TIntermTyped *deepCopy() const override { return new TIntermBinary(*this); }
static TOperator GetMulOpBasedOnOperands(const TType &left, const TType &right);
static TOperator GetMulAssignOpBasedOnOperands(const TType &left, const TType &right);
TIntermBinary *getAsBinaryNode() override { return this; }; TIntermBinary *getAsBinaryNode() override { return this; };
void traverse(TIntermTraverser *it) override; void traverse(TIntermTraverser *it) override;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override; bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
......
...@@ -3574,6 +3574,49 @@ bool TParseContext::binaryOpCommonCheck(TOperator op, ...@@ -3574,6 +3574,49 @@ bool TParseContext::binaryOpCommonCheck(TOperator op,
return true; return true;
} }
bool TParseContext::isMultiplicationTypeCombinationValid(TOperator op,
const TType &left,
const TType &right)
{
switch (op)
{
case EOpMul:
case EOpMulAssign:
return left.getNominalSize() == right.getNominalSize() &&
left.getSecondarySize() == right.getSecondarySize();
case EOpVectorTimesScalar:
return true;
case EOpVectorTimesScalarAssign:
ASSERT(!left.isMatrix() && !right.isMatrix());
return left.isVector() && !right.isVector();
case EOpVectorTimesMatrix:
return left.getNominalSize() == right.getRows();
case EOpVectorTimesMatrixAssign:
ASSERT(!left.isMatrix() && right.isMatrix());
return left.isVector() && left.getNominalSize() == right.getRows() &&
left.getNominalSize() == right.getCols();
case EOpMatrixTimesVector:
return left.getCols() == right.getNominalSize();
case EOpMatrixTimesScalar:
return true;
case EOpMatrixTimesScalarAssign:
ASSERT(left.isMatrix() && !right.isMatrix());
return !right.isVector();
case EOpMatrixTimesMatrix:
return left.getCols() == right.getRows();
case EOpMatrixTimesMatrixAssign:
ASSERT(left.isMatrix() && right.isMatrix());
// We need to check two things:
// 1. The matrix multiplication step is valid.
// 2. The result will have the same number of columns as the lvalue.
return left.getCols() == right.getRows() && left.getCols() == right.getCols();
default:
UNREACHABLE();
return false;
}
}
TIntermTyped *TParseContext::addBinaryMathInternal(TOperator op, TIntermTyped *TParseContext::addBinaryMathInternal(TOperator op,
TIntermTyped *left, TIntermTyped *left,
TIntermTyped *right, TIntermTyped *right,
...@@ -3634,6 +3677,15 @@ TIntermTyped *TParseContext::addBinaryMathInternal(TOperator op, ...@@ -3634,6 +3677,15 @@ TIntermTyped *TParseContext::addBinaryMathInternal(TOperator op,
break; break;
} }
if (op == EOpMul)
{
op = TIntermBinary::GetMulOpBasedOnOperands(left->getType(), right->getType());
if (!isMultiplicationTypeCombinationValid(op, left->getType(), right->getType()))
{
return nullptr;
}
}
TIntermBinary *node = new TIntermBinary(op, left, right); TIntermBinary *node = new TIntermBinary(op, left, right);
node->setLine(loc); node->setLine(loc);
...@@ -3688,6 +3740,14 @@ TIntermTyped *TParseContext::createAssign(TOperator op, ...@@ -3688,6 +3740,14 @@ TIntermTyped *TParseContext::createAssign(TOperator op,
{ {
if (binaryOpCommonCheck(op, left, right, loc)) if (binaryOpCommonCheck(op, left, right, loc))
{ {
if (op == EOpMulAssign)
{
op = TIntermBinary::GetMulAssignOpBasedOnOperands(left->getType(), right->getType());
if (!isMultiplicationTypeCombinationValid(op, left->getType(), right->getType()))
{
return nullptr;
}
}
TIntermBinary *node = new TIntermBinary(op, left, right); TIntermBinary *node = new TIntermBinary(op, left, right);
node->setLine(loc); node->setLine(loc);
......
...@@ -389,6 +389,9 @@ class TParseContext : angle::NonCopyable ...@@ -389,6 +389,9 @@ class TParseContext : angle::NonCopyable
bool checkIsValidTypeAndQualifierForArray(const TSourceLoc &indexLocation, bool checkIsValidTypeAndQualifierForArray(const TSourceLoc &indexLocation,
const TPublicType &elementType); const TPublicType &elementType);
// Assumes that multiplication op has already been set based on the types.
bool isMultiplicationTypeCombinationValid(TOperator op, const TType &left, const TType &right);
TIntermTyped *addBinaryMathInternal( TIntermTyped *addBinaryMathInternal(
TOperator op, TIntermTyped *left, TIntermTyped *right, const TSourceLoc &loc); TOperator op, TIntermTyped *left, TIntermTyped *right, const TSourceLoc &loc);
TIntermTyped *createAssign( TIntermTyped *createAssign(
......
...@@ -60,7 +60,8 @@ bool RemovePowTraverser::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -60,7 +60,8 @@ bool RemovePowTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
log->setLine(node->getLine()); log->setLine(node->getLine());
log->setType(x->getType()); log->setType(x->getType());
TIntermBinary *mul = new TIntermBinary(EOpMul, y, log); TOperator op = TIntermBinary::GetMulOpBasedOnOperands(y->getType(), log->getType());
TIntermBinary *mul = new TIntermBinary(op, y, log);
mul->setLine(node->getLine()); mul->setLine(node->getLine());
bool valid = mul->promote(); bool valid = mul->promote();
UNUSED_ASSERTION_VARIABLE(valid); UNUSED_ASSERTION_VARIABLE(valid);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment