Commit 56229f1b by Olli Etuaho

Remove TIntermediate::addConstantUnion

This includes asserts in TConstantUnion to reveal incorrect usage of union - reading a different field of an union that has last been set is undefined behavior in C++. Existing issues with accessing incorrect fields of constant unions are fixed. BUG=angleproject:1490 TEST=angle_unittests Change-Id: Idd6b7a871d73e2928f117a9348c92043612fab82
parent eb7f90fd
......@@ -72,6 +72,36 @@ TConstantUnion::TConstantUnion()
type = EbtVoid;
}
int TConstantUnion::getIConst() const
{
ASSERT(type == EbtInt);
return iConst;
}
unsigned int TConstantUnion::getUConst() const
{
ASSERT(type == EbtUInt);
return uConst;
}
float TConstantUnion::getFConst() const
{
ASSERT(type == EbtFloat);
return fConst;
}
bool TConstantUnion::getBConst() const
{
ASSERT(type == EbtBool);
return bConst;
}
TYuvCscStandardEXT TConstantUnion::getYuvCscStandardEXTConst() const
{
ASSERT(type == EbtYuvCscStandardEXT);
return yuvCscStandardEXTConst;
}
bool TConstantUnion::cast(TBasicType newType, const TConstantUnion &constant)
{
switch (newType)
......
......@@ -52,11 +52,11 @@ class TConstantUnion
type = EbtYuvCscStandardEXT;
}
int getIConst() const { return iConst; }
unsigned int getUConst() const { return uConst; }
float getFConst() const { return fConst; }
bool getBConst() const { return bConst; }
TYuvCscStandardEXT getYuvCscStandardEXTConst() const { return yuvCscStandardEXTConst; }
int getIConst() const;
unsigned int getUConst() const;
float getFConst() const;
bool getBConst() const;
TYuvCscStandardEXT getYuvCscStandardEXTConst() const;
bool operator==(const int i) const;
bool operator==(const unsigned int u) const;
......
......@@ -43,20 +43,4 @@ TIntermBlock *TIntermediate::EnsureBlock(TIntermNode *node)
return blockNode;
}
//
// Constant terminal nodes. Has a union that contains bool, float or int constants
//
// Returns the constant union node created.
//
TIntermConstantUnion *TIntermediate::addConstantUnion(const TConstantUnion *constantUnion,
const TType &type,
const TSourceLoc &line)
{
TIntermConstantUnion *node = new TIntermConstantUnion(constantUnion, type);
node->setLine(line);
return node;
}
} // namespace sh
......@@ -14,20 +14,15 @@ namespace sh
//
// Set of helper functions to help build the tree.
// TODO(oetuaho@nvidia.com): Clean this up, it doesn't need to be a class.
//
class TIntermediate
{
public:
POOL_ALLOCATOR_NEW_DELETE();
TIntermediate() {}
static TIntermBlock *EnsureBlock(TIntermNode *node);
TIntermConstantUnion *addConstantUnion(const TConstantUnion *constantUnion,
const TType &type,
const TSourceLoc &line);
private:
void operator=(TIntermediate &); // prevent assignments
TIntermediate(){};
};
} // namespace sh
......
......@@ -128,8 +128,7 @@ TParseContext::TParseContext(TSymbolTable &symt,
bool checksPrecErrors,
TDiagnostics *diagnostics,
const ShBuiltInResources &resources)
: intermediate(),
symbolTable(symt),
: symbolTable(symt),
mDeferredNonEmptyDeclarationErrorCheck(false),
mShaderType(type),
mShaderSpec(spec),
......@@ -764,12 +763,14 @@ bool TParseContext::checkIsNonVoid(const TSourceLoc &line,
// This function checks to see if the node (for the expression) contains a scalar boolean expression
// or not.
void TParseContext::checkIsScalarBool(const TSourceLoc &line, const TIntermTyped *type)
bool TParseContext::checkIsScalarBool(const TSourceLoc &line, const TIntermTyped *type)
{
if (type->getBasicType() != EbtBool || type->isArray() || type->isMatrix() || type->isVector())
{
error(line, "boolean expression expected", "");
return false;
}
return true;
}
// This function checks to see if the node (for the expression) contains a scalar boolean expression
......@@ -1558,6 +1559,15 @@ sh::WorkGroupSize TParseContext::getComputeShaderLocalSize() const
return result;
}
TIntermConstantUnion *TParseContext::addScalarLiteral(const TConstantUnion *constantUnion,
const TSourceLoc &line)
{
TIntermConstantUnion *node = new TIntermConstantUnion(
constantUnion, TType(constantUnion->getType(), EbpUndefined, EvqConst));
node->setLine(line);
return node;
}
/////////////////////////////////////////////////////////////////////////////////
//
// Non-Errors.
......@@ -1654,10 +1664,12 @@ TIntermTyped *TParseContext::parseVariableIdentifier(const TSourceLoc &location,
"gl_ViewID_OVR");
}
TIntermTyped *node = nullptr;
if (variable->getConstPointer())
{
const TConstantUnion *constArray = variable->getConstPointer();
return intermediate.addConstantUnion(constArray, variable->getType(), location);
node = new TIntermConstantUnion(constArray, variable->getType());
}
else if (variable->getType().getQualifier() == EvqWorkGroupSize &&
mComputeShaderLocalSizeDeclared)
......@@ -1676,15 +1688,15 @@ TIntermTyped *TParseContext::parseVariableIdentifier(const TSourceLoc &location,
TType type(variable->getType());
type.setQualifier(EvqConst);
return intermediate.addConstantUnion(constArray, type, location);
node = new TIntermConstantUnion(constArray, type);
}
else
{
TIntermSymbol *symbolNode =
new TIntermSymbol(variable->getUniqueId(), variable->getName(), variable->getType());
symbolNode->setLine(location);
return symbolNode;
node = new TIntermSymbol(variable->getUniqueId(), variable->getName(), variable->getType());
}
ASSERT(node != nullptr);
node->setLine(location);
return node;
}
// Initializers show up in several places in the grammar. Have one set of
......@@ -1898,10 +1910,10 @@ TIntermNode *TParseContext::addIfElse(TIntermTyped *cond,
TIntermNodePair code,
const TSourceLoc &loc)
{
checkIsScalarBool(loc, cond);
bool isScalarBool = checkIsScalarBool(loc, cond);
// For compile time constant conditions, prune the code now.
if (cond->getAsConstantUnion())
if (isScalarBool && cond->getAsConstantUnion())
{
if (cond->getAsConstantUnion()->getBConst(0) == true)
{
......@@ -3338,10 +3350,7 @@ TIntermTyped *TParseContext::addIndexExpression(TIntermTyped *baseExpression,
error(location, " left of '[' is not of type array, matrix, or vector ", "expression");
}
TConstantUnion *unionArray = new TConstantUnion[1];
unionArray->setFConst(0.0f);
return intermediate.addConstantUnion(unionArray, TType(EbtFloat, EbpHigh, EvqConst),
location);
return TIntermTyped::CreateZero(TType(EbtFloat, EbpHigh, EvqConst));
}
TIntermConstantUnion *indexConstantUnion = indexExpression->getAsConstantUnion();
......@@ -3377,7 +3386,15 @@ TIntermTyped *TParseContext::addIndexExpression(TIntermTyped *baseExpression,
// handle this case is to report a warning instead of an error and force the index to be in
// the correct range.
bool outOfRangeIndexIsError = indexExpression->getQualifier() == EvqConst;
int index = indexConstantUnion->getIConst(0);
int index = 0;
if (indexConstantUnion->getBasicType() == EbtInt)
{
index = indexConstantUnion->getIConst(0);
}
else if (indexConstantUnion->getBasicType() == EbtUInt)
{
index = static_cast<int>(indexConstantUnion->getUConst(0));
}
int safeIndex = -1;
......@@ -3428,11 +3445,12 @@ TIntermTyped *TParseContext::addIndexExpression(TIntermTyped *baseExpression,
// Data of constant unions can't be changed, because it may be shared with other
// constant unions or even builtins, like gl_MaxDrawBuffers. Instead use a new
// sanitized object.
if (safeIndex != index)
if (safeIndex != index || indexConstantUnion->getBasicType() != EbtInt)
{
TConstantUnion *safeConstantUnion = new TConstantUnion();
safeConstantUnion->setIConst(safeIndex);
indexConstantUnion->replaceConstantUnion(safeConstantUnion);
indexConstantUnion->getTypePointer()->setBasicType(EbtInt);
}
TIntermBinary *node = new TIntermBinary(EOpIndexDirect, baseExpression, indexExpression);
......@@ -4555,14 +4573,12 @@ TIntermTyped *TParseContext::addBinaryMathBooleanResult(TOperator op,
const TSourceLoc &loc)
{
TIntermTyped *node = addBinaryMathInternal(op, left, right, loc);
if (node == 0)
if (node == nullptr)
{
binaryOpError(loc, GetOperatorString(op), left->getCompleteString(),
right->getCompleteString());
TConstantUnion *unionArray = new TConstantUnion[1];
unionArray->setBConst(false);
return intermediate.addConstantUnion(unionArray, TType(EbtBool, EbpUndefined, EvqConst),
loc);
node = TIntermTyped::CreateZero(TType(EbtBool, EbpUndefined, EvqConst));
node->setLine(loc);
}
return node;
}
......@@ -4894,7 +4910,10 @@ TIntermTyped *TParseContext::addMethod(TFunction *fnCall,
}
}
unionArray->setIConst(arraySize);
return intermediate.addConstantUnion(unionArray, TType(EbtInt, EbpUndefined, EvqConst), loc);
TIntermConstantUnion *node =
new TIntermConstantUnion(unionArray, TType(EbtInt, EbpUndefined, EvqConst));
node->setLine(loc);
return node;
}
TIntermTyped *TParseContext::addNonConstructorFunctionCall(TFunction *fnCall,
......@@ -5000,7 +5019,10 @@ TIntermTyped *TParseContext::addTernarySelection(TIntermTyped *cond,
TIntermTyped *falseExpression,
const TSourceLoc &loc)
{
checkIsScalarBool(loc, cond);
if (!checkIsScalarBool(loc, cond))
{
return falseExpression;
}
if (trueExpression->getType() != falseExpression->getType())
{
......
......@@ -90,6 +90,9 @@ class TParseContext : angle::NonCopyable
bool declaringFunction() const { return mDeclaringFunction; }
TIntermConstantUnion *addScalarLiteral(const TConstantUnion *constantUnion,
const TSourceLoc &line);
// This method is guaranteed to succeed, even if no variable with 'name' exists.
const TVariable *getNamedVariable(const TSourceLoc &location,
const TString *name,
......@@ -125,7 +128,7 @@ class TParseContext : angle::NonCopyable
bool checkIsValidQualifierForArray(const TSourceLoc &line, const TPublicType &elementQualifier);
bool checkIsValidTypeForArray(const TSourceLoc &line, const TPublicType &elementType);
bool checkIsNonVoid(const TSourceLoc &line, const TString &identifier, const TBasicType &type);
void checkIsScalarBool(const TSourceLoc &line, const TIntermTyped *type);
bool checkIsScalarBool(const TSourceLoc &line, const TIntermTyped *type);
void checkIsScalarBool(const TSourceLoc &line, const TPublicType &pType);
bool checkIsNotOpaqueType(const TSourceLoc &line,
const TTypeSpecifierNonArray &pType,
......@@ -405,8 +408,7 @@ class TParseContext : angle::NonCopyable
TIntermTyped *falseExpression,
const TSourceLoc &line);
// TODO(jmadill): make these private
TIntermediate intermediate; // to build a parse tree
// TODO(jmadill): make this private
TSymbolTable &symbolTable; // symbol table that goes with the language currently being parsed
private:
......
......@@ -607,7 +607,7 @@ bool CollectVariablesTraverser::visitBinary(Visit, TIntermBinary *binaryNode)
ASSERT(namedBlock);
namedBlock->staticUse = true;
unsigned int fieldIndex = constantUnion->getUConst(0);
unsigned int fieldIndex = static_cast<unsigned int>(constantUnion->getIConst(0));
ASSERT(fieldIndex < namedBlock->fields.size());
namedBlock->fields[fieldIndex].staticUse = true;
return false;
......
......@@ -261,22 +261,22 @@ primary_expression
| INTCONSTANT {
TConstantUnion *unionArray = new TConstantUnion[1];
unionArray->setIConst($1.i);
$$ = context->intermediate.addConstantUnion(unionArray, TType(EbtInt, EbpUndefined, EvqConst), @1);
$$ = context->addScalarLiteral(unionArray, @1);
}
| UINTCONSTANT {
TConstantUnion *unionArray = new TConstantUnion[1];
unionArray->setUConst($1.u);
$$ = context->intermediate.addConstantUnion(unionArray, TType(EbtUInt, EbpUndefined, EvqConst), @1);
$$ = context->addScalarLiteral(unionArray, @1);
}
| FLOATCONSTANT {
TConstantUnion *unionArray = new TConstantUnion[1];
unionArray->setFConst($1.f);
$$ = context->intermediate.addConstantUnion(unionArray, TType(EbtFloat, EbpUndefined, EvqConst), @1);
$$ = context->addScalarLiteral(unionArray, @1);
}
| BOOLCONSTANT {
TConstantUnion *unionArray = new TConstantUnion[1];
unionArray->setBConst($1.b);
$$ = context->intermediate.addConstantUnion(unionArray, TType(EbtBool, EbpUndefined, EvqConst), @1);
$$ = context->addScalarLiteral(unionArray, @1);
}
| YUVCSCSTANDARDEXTCONSTANT {
if (!context->isExtensionEnabled("GL_EXT_YUV_target")) {
......@@ -284,7 +284,7 @@ primary_expression
}
TConstantUnion *unionArray = new TConstantUnion[1];
unionArray->setYuvCscStandardEXTConst(getYuvCscStandardEXT($1.string->c_str()));
$$ = context->intermediate.addConstantUnion(unionArray, TType(EbtYuvCscStandardEXT, EbpUndefined, EvqConst), @1);
$$ = context->addScalarLiteral(unionArray, @1);
}
| LEFT_PAREN expression RIGHT_PAREN {
$$ = $2;
......
......@@ -2543,7 +2543,7 @@ yyreduce:
{
TConstantUnion *unionArray = new TConstantUnion[1];
unionArray->setIConst((yyvsp[0].lex).i);
(yyval.interm.intermTypedNode) = context->intermediate.addConstantUnion(unionArray, TType(EbtInt, EbpUndefined, EvqConst), (yylsp[0]));
(yyval.interm.intermTypedNode) = context->addScalarLiteral(unionArray, (yylsp[0]));
}
break;
......@@ -2553,7 +2553,7 @@ yyreduce:
{
TConstantUnion *unionArray = new TConstantUnion[1];
unionArray->setUConst((yyvsp[0].lex).u);
(yyval.interm.intermTypedNode) = context->intermediate.addConstantUnion(unionArray, TType(EbtUInt, EbpUndefined, EvqConst), (yylsp[0]));
(yyval.interm.intermTypedNode) = context->addScalarLiteral(unionArray, (yylsp[0]));
}
break;
......@@ -2563,7 +2563,7 @@ yyreduce:
{
TConstantUnion *unionArray = new TConstantUnion[1];
unionArray->setFConst((yyvsp[0].lex).f);
(yyval.interm.intermTypedNode) = context->intermediate.addConstantUnion(unionArray, TType(EbtFloat, EbpUndefined, EvqConst), (yylsp[0]));
(yyval.interm.intermTypedNode) = context->addScalarLiteral(unionArray, (yylsp[0]));
}
break;
......@@ -2573,7 +2573,7 @@ yyreduce:
{
TConstantUnion *unionArray = new TConstantUnion[1];
unionArray->setBConst((yyvsp[0].lex).b);
(yyval.interm.intermTypedNode) = context->intermediate.addConstantUnion(unionArray, TType(EbtBool, EbpUndefined, EvqConst), (yylsp[0]));
(yyval.interm.intermTypedNode) = context->addScalarLiteral(unionArray, (yylsp[0]));
}
break;
......@@ -2586,7 +2586,7 @@ yyreduce:
}
TConstantUnion *unionArray = new TConstantUnion[1];
unionArray->setYuvCscStandardEXTConst(getYuvCscStandardEXT((yyvsp[0].lex).string->c_str()));
(yyval.interm.intermTypedNode) = context->intermediate.addConstantUnion(unionArray, TType(EbtYuvCscStandardEXT, EbpUndefined, EvqConst), (yylsp[0]));
(yyval.interm.intermTypedNode) = context->addScalarLiteral(unionArray, (yylsp[0]));
}
break;
......
......@@ -3974,3 +3974,58 @@ TEST_F(FragmentShaderValidationTest, FunctionDefinedWithReservedName)
FAIL() << "Shader compilation succeeded, expecting failure:\n" << mInfoLog;
}
}
// Test that ops with mismatching operand types are disallowed and don't result in an assert.
// This makes sure that constant folding doesn't fetch invalid union values in case operand types
// mismatch.
TEST_F(FragmentShaderValidationTest, InvalidOpsWithConstantOperandsDontAssert)
{
const std::string &shaderString =
"#version 300 es\n"
"precision mediump float;\n"
"out vec4 my_FragColor;\n"
"void main()\n"
"{\n"
" float f1 = 0.5 / 2;\n"
" float f2 = true + 0.5;\n"
" float f3 = float[2](0.0, 1.0)[1.0];\n"
" float f4 = float[2](0.0, 1.0)[true];\n"
" float f5 = true ? 1.0 : 0;\n"
" float f6 = 1.0 ? 1.0 : 2.0;\n"
" my_FragColor = vec4(0.0);\n"
"}\n";
if (compile(shaderString))
{
FAIL() << "Shader compilation succeeded, expecting failure:\n" << mInfoLog;
}
}
// Test that case labels with invalid types don't assert
TEST_F(FragmentShaderValidationTest, CaseLabelsWithInvalidTypesDontAssert)
{
const std::string &shaderString =
"#version 300 es\n"
"precision mediump float;\n"
"out vec4 my_FragColor;\n"
"uniform int i;\n"
"void main()\n"
"{\n"
" float f = 0.0;\n"
" switch (i)\n"
" {\n"
" case 0u:\n"
" f = 0.0;\n"
" case true:\n"
" f = 1.0;\n"
" case 2.0:\n"
" f = 2.0;\n"
" }\n"
" my_FragColor = vec4(0.0);\n"
"}\n";
if (compile(shaderString))
{
FAIL() << "Shader compilation succeeded, expecting failure:\n" << mInfoLog;
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment