Commit 1dded803 by Olli Etuaho Committed by Commit Bot

Check multiplication validity in ParseContext

This improves separation of responsibilities in the code: ParseContext should handle operand type validation, while TIntermBinary::promote should ideally only determine the type of the node based on the operation and operands. BUG=angleproject:952 TEST=angle_unittests Change-Id: I9a8d8ede21cdf35de631623a62194c0da5c604d2 Reviewed-on: https://chromium-review.googlesource.com/372622 Commit-Queue: Olli Etuaho <oetuaho@nvidia.com> Reviewed-by: 's avatarCorentin Wallez <cwallez@chromium.org>
parent 9949f04d
......@@ -34,41 +34,6 @@ TPrecision GetHigherPrecision(TPrecision left, TPrecision right)
return left > right ? left : right;
}
bool ValidateMultiplication(TOperator op, const TType &left, const TType &right)
{
switch (op)
{
case EOpMul:
case EOpMulAssign:
return left.getNominalSize() == right.getNominalSize() &&
left.getSecondarySize() == right.getSecondarySize();
case EOpVectorTimesScalar:
case EOpVectorTimesScalarAssign:
return true;
case EOpVectorTimesMatrix:
return left.getNominalSize() == right.getRows();
case EOpVectorTimesMatrixAssign:
return left.getNominalSize() == right.getRows() &&
left.getNominalSize() == right.getCols();
case EOpMatrixTimesVector:
return left.getCols() == right.getNominalSize();
case EOpMatrixTimesScalar:
case EOpMatrixTimesScalarAssign:
return true;
case EOpMatrixTimesMatrix:
return left.getCols() == right.getRows();
case EOpMatrixTimesMatrixAssign:
// We need to check two things:
// 1. The matrix multiplication step is valid.
// 2. The result will have the same number of columns as the lvalue.
return left.getCols() == right.getRows() && left.getCols() == right.getCols();
default:
UNREACHABLE();
return false;
}
}
TConstantUnion *Vectorize(const TConstantUnion &constant, size_t size)
{
TConstantUnion *constUnion = new TConstantUnion[size];
......@@ -513,6 +478,94 @@ bool TIntermOperator::isConstructor() const
}
}
TOperator TIntermBinary::GetMulOpBasedOnOperands(const TType &left, const TType &right)
{
if (left.isMatrix())
{
if (right.isMatrix())
{
return EOpMatrixTimesMatrix;
}
else
{
if (right.isVector())
{
return EOpMatrixTimesVector;
}
else
{
return EOpMatrixTimesScalar;
}
}
}
else
{
if (right.isMatrix())
{
if (left.isVector())
{
return EOpVectorTimesMatrix;
}
else
{
return EOpMatrixTimesScalar;
}
}
else
{
// Neither operand is a matrix.
if (left.isVector() == right.isVector())
{
// Leave as component product.
return EOpMul;
}
else
{
return EOpVectorTimesScalar;
}
}
}
}
TOperator TIntermBinary::GetMulAssignOpBasedOnOperands(const TType &left, const TType &right)
{
if (left.isMatrix())
{
if (right.isMatrix())
{
return EOpMatrixTimesMatrixAssign;
}
else
{
// right should be scalar, but this may not be validated yet.
return EOpMatrixTimesScalarAssign;
}
}
else
{
if (right.isMatrix())
{
// Left should be a vector, but this may not be validated yet.
return EOpVectorTimesMatrixAssign;
}
else
{
// Neither operand is a matrix.
if (left.isVector() == right.isVector())
{
// Leave as component product.
return EOpMulAssign;
}
else
{
// left should be vector and right should be scalar, but this may not be validated
// yet.
return EOpVectorTimesScalarAssign;
}
}
}
}
//
// Make sure the type of a unary operator is appropriate for its
// combination of operation and operand type.
......@@ -570,6 +623,9 @@ bool TIntermBinary::promote()
{
ASSERT(mLeft->isArray() == mRight->isArray());
ASSERT(!isMultiplication() ||
mOp == GetMulOpBasedOnOperands(mLeft->getType(), mRight->getType()));
//
// Base assumption: just make the type the same as the left
// operand. Then only deviations from this need be coded.
......@@ -633,204 +689,118 @@ bool TIntermBinary::promote()
// Can these two operands be combined?
//
TBasicType basicType = mLeft->getBasicType();
switch (mOp)
{
case EOpMul:
if (!mLeft->isMatrix() && mRight->isMatrix())
{
if (mLeft->isVector())
{
mOp = EOpVectorTimesMatrix;
setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(mRight->getCols()), 1));
}
else
case EOpMul:
break;
case EOpMatrixTimesScalar:
if (mRight->isMatrix())
{
mOp = EOpMatrixTimesScalar;
setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(mRight->getCols()),
static_cast<unsigned char>(mRight->getRows())));
}
}
else if (mLeft->isMatrix() && !mRight->isMatrix())
{
if (mRight->isVector())
{
mOp = EOpMatrixTimesVector;
setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(mLeft->getRows()), 1));
}
else
{
mOp = EOpMatrixTimesScalar;
}
}
else if (mLeft->isMatrix() && mRight->isMatrix())
{
mOp = EOpMatrixTimesMatrix;
break;
case EOpMatrixTimesVector:
setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(mRight->getCols()),
static_cast<unsigned char>(mLeft->getRows())));
}
else if (!mLeft->isMatrix() && !mRight->isMatrix())
{
if (mLeft->isVector() && mRight->isVector())
{
// leave as component product
}
else if (mLeft->isVector() || mRight->isVector())
{
mOp = EOpVectorTimesScalar;
setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(nominalSize), 1));
}
}
else
{
UNREACHABLE();
return false;
}
if (!ValidateMultiplication(mOp, mLeft->getType(), mRight->getType()))
{
return false;
}
break;
case EOpMulAssign:
if (!mLeft->isMatrix() && mRight->isMatrix())
{
if (mLeft->isVector())
{
mOp = EOpVectorTimesMatrixAssign;
}
else
{
return false;
}
}
else if (mLeft->isMatrix() && !mRight->isMatrix())
{
if (mRight->isVector())
{
return false;
}
else
{
mOp = EOpMatrixTimesScalarAssign;
}
}
else if (mLeft->isMatrix() && mRight->isMatrix())
{
mOp = EOpMatrixTimesMatrixAssign;
static_cast<unsigned char>(mLeft->getRows()), 1));
break;
case EOpMatrixTimesMatrix:
setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(mRight->getCols()),
static_cast<unsigned char>(mLeft->getRows())));
}
else if (!mLeft->isMatrix() && !mRight->isMatrix())
{
if (mLeft->isVector() && mRight->isVector())
break;
case EOpVectorTimesScalar:
setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(nominalSize), 1));
break;
case EOpVectorTimesMatrix:
setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(mRight->getCols()), 1));
break;
case EOpMulAssign:
case EOpVectorTimesScalarAssign:
case EOpVectorTimesMatrixAssign:
case EOpMatrixTimesScalarAssign:
case EOpMatrixTimesMatrixAssign:
ASSERT(mOp == GetMulAssignOpBasedOnOperands(mLeft->getType(), mRight->getType()));
break;
case EOpAssign:
case EOpInitialize:
// No more additional checks are needed.
ASSERT((mLeft->getNominalSize() == mRight->getNominalSize()) &&
(mLeft->getSecondarySize() == mRight->getSecondarySize()));
break;
case EOpAdd:
case EOpSub:
case EOpDiv:
case EOpIMod:
case EOpBitShiftLeft:
case EOpBitShiftRight:
case EOpBitwiseAnd:
case EOpBitwiseXor:
case EOpBitwiseOr:
case EOpAddAssign:
case EOpSubAssign:
case EOpDivAssign:
case EOpIModAssign:
case EOpBitShiftLeftAssign:
case EOpBitShiftRightAssign:
case EOpBitwiseAndAssign:
case EOpBitwiseXorAssign:
case EOpBitwiseOrAssign:
if ((mLeft->isMatrix() && mRight->isVector()) ||
(mLeft->isVector() && mRight->isMatrix()))
{
// leave as component product
return false;
}
else if (mLeft->isVector() || mRight->isVector())
// Are the sizes compatible?
if (mLeft->getNominalSize() != mRight->getNominalSize() ||
mLeft->getSecondarySize() != mRight->getSecondarySize())
{
if (!mLeft->isVector())
// If the nominal sizes of operands do not match:
// One of them must be a scalar.
if (!mLeft->isScalar() && !mRight->isScalar())
return false;
mOp = EOpVectorTimesScalarAssign;
setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(mLeft->getNominalSize()), 1));
}
}
else
{
UNREACHABLE();
return false;
}
if (!ValidateMultiplication(mOp, mLeft->getType(), mRight->getType()))
{
return false;
}
break;
case EOpAssign:
case EOpInitialize:
// No more additional checks are needed.
ASSERT((mLeft->getNominalSize() == mRight->getNominalSize()) &&
(mLeft->getSecondarySize() == mRight->getSecondarySize()));
break;
case EOpAdd:
case EOpSub:
case EOpDiv:
case EOpIMod:
case EOpBitShiftLeft:
case EOpBitShiftRight:
case EOpBitwiseAnd:
case EOpBitwiseXor:
case EOpBitwiseOr:
case EOpAddAssign:
case EOpSubAssign:
case EOpDivAssign:
case EOpIModAssign:
case EOpBitShiftLeftAssign:
case EOpBitShiftRightAssign:
case EOpBitwiseAndAssign:
case EOpBitwiseXorAssign:
case EOpBitwiseOrAssign:
if ((mLeft->isMatrix() && mRight->isVector()) ||
(mLeft->isVector() && mRight->isMatrix()))
{
return false;
}
// Are the sizes compatible?
if (mLeft->getNominalSize() != mRight->getNominalSize() ||
mLeft->getSecondarySize() != mRight->getSecondarySize())
{
// If the nominal sizes of operands do not match:
// One of them must be a scalar.
if (!mLeft->isScalar() && !mRight->isScalar())
return false;
// In the case of compound assignment other than multiply-assign,
// the right side needs to be a scalar. Otherwise a vector/matrix
// would be assigned to a scalar. A scalar can't be shifted by a
// vector either.
if (!mRight->isScalar() &&
(isAssignment() ||
mOp == EOpBitShiftLeft ||
mOp == EOpBitShiftRight))
return false;
}
// In the case of compound assignment other than multiply-assign,
// the right side needs to be a scalar. Otherwise a vector/matrix
// would be assigned to a scalar. A scalar can't be shifted by a
// vector either.
if (!mRight->isScalar() &&
(isAssignment() || mOp == EOpBitShiftLeft || mOp == EOpBitShiftRight))
return false;
}
{
const int secondarySize = std::max(
mLeft->getSecondarySize(), mRight->getSecondarySize());
setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(nominalSize),
static_cast<unsigned char>(secondarySize)));
if (mLeft->isArray())
{
ASSERT(mLeft->getArraySize() == mRight->getArraySize());
mType.setArraySize(mLeft->getArraySize());
const int secondarySize =
std::max(mLeft->getSecondarySize(), mRight->getSecondarySize());
setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(nominalSize),
static_cast<unsigned char>(secondarySize)));
if (mLeft->isArray())
{
ASSERT(mLeft->getArraySize() == mRight->getArraySize());
mType.setArraySize(mLeft->getArraySize());
}
}
}
break;
break;
case EOpEqual:
case EOpNotEqual:
case EOpLessThan:
case EOpGreaterThan:
case EOpLessThanEqual:
case EOpGreaterThanEqual:
ASSERT((mLeft->getNominalSize() == mRight->getNominalSize()) &&
(mLeft->getSecondarySize() == mRight->getSecondarySize()));
setType(TType(EbtBool, EbpUndefined));
break;
case EOpEqual:
case EOpNotEqual:
case EOpLessThan:
case EOpGreaterThan:
case EOpLessThanEqual:
case EOpGreaterThanEqual:
ASSERT((mLeft->getNominalSize() == mRight->getNominalSize()) &&
(mLeft->getSecondarySize() == mRight->getSecondarySize()));
setType(TType(EbtBool, EbpUndefined));
break;
default:
return false;
default:
return false;
}
return true;
}
......
......@@ -417,6 +417,9 @@ class TIntermBinary : public TIntermOperator
TIntermTyped *deepCopy() const override { return new TIntermBinary(*this); }
static TOperator GetMulOpBasedOnOperands(const TType &left, const TType &right);
static TOperator GetMulAssignOpBasedOnOperands(const TType &left, const TType &right);
TIntermBinary *getAsBinaryNode() override { return this; };
void traverse(TIntermTraverser *it) override;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
......
......@@ -3574,6 +3574,49 @@ bool TParseContext::binaryOpCommonCheck(TOperator op,
return true;
}
bool TParseContext::isMultiplicationTypeCombinationValid(TOperator op,
const TType &left,
const TType &right)
{
switch (op)
{
case EOpMul:
case EOpMulAssign:
return left.getNominalSize() == right.getNominalSize() &&
left.getSecondarySize() == right.getSecondarySize();
case EOpVectorTimesScalar:
return true;
case EOpVectorTimesScalarAssign:
ASSERT(!left.isMatrix() && !right.isMatrix());
return left.isVector() && !right.isVector();
case EOpVectorTimesMatrix:
return left.getNominalSize() == right.getRows();
case EOpVectorTimesMatrixAssign:
ASSERT(!left.isMatrix() && right.isMatrix());
return left.isVector() && left.getNominalSize() == right.getRows() &&
left.getNominalSize() == right.getCols();
case EOpMatrixTimesVector:
return left.getCols() == right.getNominalSize();
case EOpMatrixTimesScalar:
return true;
case EOpMatrixTimesScalarAssign:
ASSERT(left.isMatrix() && !right.isMatrix());
return !right.isVector();
case EOpMatrixTimesMatrix:
return left.getCols() == right.getRows();
case EOpMatrixTimesMatrixAssign:
ASSERT(left.isMatrix() && right.isMatrix());
// We need to check two things:
// 1. The matrix multiplication step is valid.
// 2. The result will have the same number of columns as the lvalue.
return left.getCols() == right.getRows() && left.getCols() == right.getCols();
default:
UNREACHABLE();
return false;
}
}
TIntermTyped *TParseContext::addBinaryMathInternal(TOperator op,
TIntermTyped *left,
TIntermTyped *right,
......@@ -3634,6 +3677,15 @@ TIntermTyped *TParseContext::addBinaryMathInternal(TOperator op,
break;
}
if (op == EOpMul)
{
op = TIntermBinary::GetMulOpBasedOnOperands(left->getType(), right->getType());
if (!isMultiplicationTypeCombinationValid(op, left->getType(), right->getType()))
{
return nullptr;
}
}
TIntermBinary *node = new TIntermBinary(op, left, right);
node->setLine(loc);
......@@ -3688,6 +3740,14 @@ TIntermTyped *TParseContext::createAssign(TOperator op,
{
if (binaryOpCommonCheck(op, left, right, loc))
{
if (op == EOpMulAssign)
{
op = TIntermBinary::GetMulAssignOpBasedOnOperands(left->getType(), right->getType());
if (!isMultiplicationTypeCombinationValid(op, left->getType(), right->getType()))
{
return nullptr;
}
}
TIntermBinary *node = new TIntermBinary(op, left, right);
node->setLine(loc);
......
......@@ -389,6 +389,9 @@ class TParseContext : angle::NonCopyable
bool checkIsValidTypeAndQualifierForArray(const TSourceLoc &indexLocation,
const TPublicType &elementType);
// Assumes that multiplication op has already been set based on the types.
bool isMultiplicationTypeCombinationValid(TOperator op, const TType &left, const TType &right);
TIntermTyped *addBinaryMathInternal(
TOperator op, TIntermTyped *left, TIntermTyped *right, const TSourceLoc &loc);
TIntermTyped *createAssign(
......
......@@ -60,7 +60,8 @@ bool RemovePowTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
log->setLine(node->getLine());
log->setType(x->getType());
TIntermBinary *mul = new TIntermBinary(EOpMul, y, log);
TOperator op = TIntermBinary::GetMulOpBasedOnOperands(y->getType(), log->getType());
TIntermBinary *mul = new TIntermBinary(op, y, log);
mul->setLine(node->getLine());
bool valid = mul->promote();
UNUSED_ASSERTION_VARIABLE(valid);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment