Commit 5db69f57 by Jamie Madill Committed by Commit Bot

Add robust math to constant folding.

Previously our multiplication and other operators could do overflows, which can lead to security bugs. BUG=chromium:637050 Change-Id: Icee22a87909e205b71bda1c5bc1627fcf5e26e90 Reviewed-on: https://chromium-review.googlesource.com/382678 Commit-Queue: Jamie Madill <jmadill@chromium.org> Reviewed-by: 's avatarCorentin Wallez <cwallez@chromium.org>
parent a5615c69
...@@ -57,7 +57,7 @@ void InitializeDebugAnnotations(DebugAnnotator *debugAnnotator); ...@@ -57,7 +57,7 @@ void InitializeDebugAnnotations(DebugAnnotator *debugAnnotator);
void UninitializeDebugAnnotations(); void UninitializeDebugAnnotations();
bool DebugAnnotationsActive(); bool DebugAnnotationsActive();
} } // namespace gl
#if defined(ANGLE_ENABLE_DEBUG_TRACE) || defined(ANGLE_ENABLE_DEBUG_ANNOTATIONS) #if defined(ANGLE_ENABLE_DEBUG_TRACE) || defined(ANGLE_ENABLE_DEBUG_ANNOTATIONS)
#define ANGLE_TRACE_ENABLED #define ANGLE_TRACE_ENABLED
...@@ -153,10 +153,9 @@ bool DebugAnnotationsActive(); ...@@ -153,10 +153,9 @@ bool DebugAnnotationsActive();
// A macro for code which is not expected to be reached under valid assumptions // A macro for code which is not expected to be reached under valid assumptions
#if !defined(NDEBUG) #if !defined(NDEBUG)
#define UNREACHABLE() { \ #define UNREACHABLE() \
ERR("\t! Unreachable reached: %s(%d)\n", __FUNCTION__, __LINE__); \ ERR("\t! Unreachable reached: %s(%d)\n", __FUNCTION__, __LINE__), \
assert(false); \ assert(false)
} ANGLE_EMPTY_STATEMENT
#else #else
#define UNREACHABLE() ERR("\t! Unreachable reached: %s(%d)\n", __FUNCTION__, __LINE__) #define UNREACHABLE() ERR("\t! Unreachable reached: %s(%d)\n", __FUNCTION__, __LINE__)
#endif #endif
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
// Unfortunately ANGLE relies on ASSERT being an empty statement, which these libs don't respect. // Unfortunately ANGLE relies on ASSERT being an empty statement, which these libs don't respect.
#ifndef NOTREACHED #ifndef NOTREACHED
#define NOTREACHED() 0 #define NOTREACHED() UNREACHABLE()
#endif #endif
#endif // BASE_LOGGING_H_ #endif // BASE_LOGGING_H_
\ No newline at end of file
...@@ -7,8 +7,62 @@ ...@@ -7,8 +7,62 @@
#include "compiler/translator/ConstantUnion.h" #include "compiler/translator/ConstantUnion.h"
#include "base/numerics/safe_math.h"
#include "compiler/translator/Diagnostics.h" #include "compiler/translator/Diagnostics.h"
namespace
{
template <typename T>
T CheckedSum(base::CheckedNumeric<T> lhs,
base::CheckedNumeric<T> rhs,
TDiagnostics *diag,
const TSourceLoc &line)
{
ASSERT(lhs.IsValid() && rhs.IsValid());
auto result = lhs + rhs;
if (!result.IsValid())
{
diag->error(line, "Addition out of range", "*", "");
return 0;
}
return result.ValueOrDefault(0);
}
template <typename T>
T CheckedDiff(base::CheckedNumeric<T> lhs,
base::CheckedNumeric<T> rhs,
TDiagnostics *diag,
const TSourceLoc &line)
{
ASSERT(lhs.IsValid() && rhs.IsValid());
auto result = lhs - rhs;
if (!result.IsValid())
{
diag->error(line, "Difference out of range", "*", "");
return 0;
}
return result.ValueOrDefault(0);
}
template <typename T>
T CheckedMul(base::CheckedNumeric<T> lhs,
base::CheckedNumeric<T> rhs,
TDiagnostics *diag,
const TSourceLoc &line)
{
ASSERT(lhs.IsValid() && rhs.IsValid());
auto result = lhs * rhs;
if (!result.IsValid())
{
diag->error(line, "Multiplication out of range", "*", "");
return 0;
}
return result.ValueOrDefault(0);
}
} // anonymous namespace
TConstantUnion::TConstantUnion() TConstantUnion::TConstantUnion()
{ {
iConst = 0; iConst = 0;
...@@ -221,20 +275,21 @@ bool TConstantUnion::operator<(const TConstantUnion &constant) const ...@@ -221,20 +275,21 @@ bool TConstantUnion::operator<(const TConstantUnion &constant) const
// static // static
TConstantUnion TConstantUnion::add(const TConstantUnion &lhs, TConstantUnion TConstantUnion::add(const TConstantUnion &lhs,
const TConstantUnion &rhs, const TConstantUnion &rhs,
TDiagnostics *diag) TDiagnostics *diag,
const TSourceLoc &line)
{ {
TConstantUnion returnValue; TConstantUnion returnValue;
ASSERT(lhs.type == rhs.type); ASSERT(lhs.type == rhs.type);
switch (lhs.type) switch (lhs.type)
{ {
case EbtInt: case EbtInt:
returnValue.setIConst(lhs.iConst + rhs.iConst); returnValue.setIConst(CheckedSum<int>(lhs.iConst, rhs.iConst, diag, line));
break; break;
case EbtUInt: case EbtUInt:
returnValue.setUConst(lhs.uConst + rhs.uConst); returnValue.setUConst(CheckedSum<unsigned int>(lhs.uConst, rhs.uConst, diag, line));
break; break;
case EbtFloat: case EbtFloat:
returnValue.setFConst(lhs.fConst + rhs.fConst); returnValue.setFConst(CheckedSum<float>(lhs.fConst, rhs.fConst, diag, line));
break; break;
default: default:
UNREACHABLE(); UNREACHABLE();
...@@ -246,20 +301,21 @@ TConstantUnion TConstantUnion::add(const TConstantUnion &lhs, ...@@ -246,20 +301,21 @@ TConstantUnion TConstantUnion::add(const TConstantUnion &lhs,
// static // static
TConstantUnion TConstantUnion::sub(const TConstantUnion &lhs, TConstantUnion TConstantUnion::sub(const TConstantUnion &lhs,
const TConstantUnion &rhs, const TConstantUnion &rhs,
TDiagnostics *diag) TDiagnostics *diag,
const TSourceLoc &line)
{ {
TConstantUnion returnValue; TConstantUnion returnValue;
ASSERT(lhs.type == rhs.type); ASSERT(lhs.type == rhs.type);
switch (lhs.type) switch (lhs.type)
{ {
case EbtInt: case EbtInt:
returnValue.setIConst(lhs.iConst - rhs.iConst); returnValue.setIConst(CheckedDiff<int>(lhs.iConst, rhs.iConst, diag, line));
break; break;
case EbtUInt: case EbtUInt:
returnValue.setUConst(lhs.uConst - rhs.uConst); returnValue.setUConst(CheckedDiff<unsigned int>(lhs.uConst, rhs.uConst, diag, line));
break; break;
case EbtFloat: case EbtFloat:
returnValue.setFConst(lhs.fConst - rhs.fConst); returnValue.setFConst(CheckedDiff<float>(lhs.fConst, rhs.fConst, diag, line));
break; break;
default: default:
UNREACHABLE(); UNREACHABLE();
...@@ -271,20 +327,21 @@ TConstantUnion TConstantUnion::sub(const TConstantUnion &lhs, ...@@ -271,20 +327,21 @@ TConstantUnion TConstantUnion::sub(const TConstantUnion &lhs,
// static // static
TConstantUnion TConstantUnion::mul(const TConstantUnion &lhs, TConstantUnion TConstantUnion::mul(const TConstantUnion &lhs,
const TConstantUnion &rhs, const TConstantUnion &rhs,
TDiagnostics *diag) TDiagnostics *diag,
const TSourceLoc &line)
{ {
TConstantUnion returnValue; TConstantUnion returnValue;
ASSERT(lhs.type == rhs.type); ASSERT(lhs.type == rhs.type);
switch (lhs.type) switch (lhs.type)
{ {
case EbtInt: case EbtInt:
returnValue.setIConst(lhs.iConst * rhs.iConst); returnValue.setIConst(CheckedMul<int>(lhs.iConst, rhs.iConst, diag, line));
break; break;
case EbtUInt: case EbtUInt:
returnValue.setUConst(lhs.uConst * rhs.uConst); returnValue.setUConst(CheckedMul<unsigned int>(lhs.uConst, rhs.uConst, diag, line));
break; break;
case EbtFloat: case EbtFloat:
returnValue.setFConst(lhs.fConst * rhs.fConst); returnValue.setFConst(CheckedMul<float>(lhs.fConst, rhs.fConst, diag, line));
break; break;
default: default:
UNREACHABLE(); UNREACHABLE();
......
...@@ -46,13 +46,16 @@ class TConstantUnion ...@@ -46,13 +46,16 @@ class TConstantUnion
bool operator<(const TConstantUnion &constant) const; bool operator<(const TConstantUnion &constant) const;
static TConstantUnion add(const TConstantUnion &lhs, static TConstantUnion add(const TConstantUnion &lhs,
const TConstantUnion &rhs, const TConstantUnion &rhs,
TDiagnostics *diag); TDiagnostics *diag,
const TSourceLoc &line);
static TConstantUnion sub(const TConstantUnion &lhs, static TConstantUnion sub(const TConstantUnion &lhs,
const TConstantUnion &rhs, const TConstantUnion &rhs,
TDiagnostics *diag); TDiagnostics *diag,
const TSourceLoc &line);
static TConstantUnion mul(const TConstantUnion &lhs, static TConstantUnion mul(const TConstantUnion &lhs,
const TConstantUnion &rhs, const TConstantUnion &rhs,
TDiagnostics *diag); TDiagnostics *diag,
const TSourceLoc &line);
TConstantUnion operator%(const TConstantUnion &constant) const; TConstantUnion operator%(const TConstantUnion &constant) const;
TConstantUnion operator>>(const TConstantUnion &constant) const; TConstantUnion operator>>(const TConstantUnion &constant) const;
TConstantUnion operator<<(const TConstantUnion &constant) const; TConstantUnion operator<<(const TConstantUnion &constant) const;
......
...@@ -1026,7 +1026,8 @@ TIntermTyped *TIntermBinary::fold(TDiagnostics *diagnostics) ...@@ -1026,7 +1026,8 @@ TIntermTyped *TIntermBinary::fold(TDiagnostics *diagnostics)
{ {
return nullptr; return nullptr;
} }
TConstantUnion *constArray = leftConstant->foldBinary(mOp, rightConstant, diagnostics); TConstantUnion *constArray =
leftConstant->foldBinary(mOp, rightConstant, diagnostics, mLeft->getLine());
// Nodes may be constant folded without being qualified as constant. // Nodes may be constant folded without being qualified as constant.
return CreateFoldedNode(constArray, this, mType.getQualifier()); return CreateFoldedNode(constArray, this, mType.getQualifier());
...@@ -1097,7 +1098,8 @@ TIntermTyped *TIntermAggregate::fold(TDiagnostics *diagnostics) ...@@ -1097,7 +1098,8 @@ TIntermTyped *TIntermAggregate::fold(TDiagnostics *diagnostics)
// //
TConstantUnion *TIntermConstantUnion::foldBinary(TOperator op, TConstantUnion *TIntermConstantUnion::foldBinary(TOperator op,
TIntermConstantUnion *rightNode, TIntermConstantUnion *rightNode,
TDiagnostics *diagnostics) TDiagnostics *diagnostics,
const TSourceLoc &line)
{ {
const TConstantUnion *leftArray = getUnionArrayPointer(); const TConstantUnion *leftArray = getUnionArrayPointer();
const TConstantUnion *rightArray = rightNode->getUnionArrayPointer(); const TConstantUnion *rightArray = rightNode->getUnionArrayPointer();
...@@ -1125,12 +1127,12 @@ TConstantUnion *TIntermConstantUnion::foldBinary(TOperator op, ...@@ -1125,12 +1127,12 @@ TConstantUnion *TIntermConstantUnion::foldBinary(TOperator op,
case EOpAdd: case EOpAdd:
resultArray = new TConstantUnion[objectSize]; resultArray = new TConstantUnion[objectSize];
for (size_t i = 0; i < objectSize; i++) for (size_t i = 0; i < objectSize; i++)
resultArray[i] = TConstantUnion::add(leftArray[i], rightArray[i], diagnostics); resultArray[i] = TConstantUnion::add(leftArray[i], rightArray[i], diagnostics, line);
break; break;
case EOpSub: case EOpSub:
resultArray = new TConstantUnion[objectSize]; resultArray = new TConstantUnion[objectSize];
for (size_t i = 0; i < objectSize; i++) for (size_t i = 0; i < objectSize; i++)
resultArray[i] = TConstantUnion::sub(leftArray[i], rightArray[i], diagnostics); resultArray[i] = TConstantUnion::sub(leftArray[i], rightArray[i], diagnostics, line);
break; break;
case EOpMul: case EOpMul:
...@@ -1138,11 +1140,12 @@ TConstantUnion *TIntermConstantUnion::foldBinary(TOperator op, ...@@ -1138,11 +1140,12 @@ TConstantUnion *TIntermConstantUnion::foldBinary(TOperator op,
case EOpMatrixTimesScalar: case EOpMatrixTimesScalar:
resultArray = new TConstantUnion[objectSize]; resultArray = new TConstantUnion[objectSize];
for (size_t i = 0; i < objectSize; i++) for (size_t i = 0; i < objectSize; i++)
resultArray[i] = TConstantUnion::mul(leftArray[i], rightArray[i], diagnostics); resultArray[i] = TConstantUnion::mul(leftArray[i], rightArray[i], diagnostics, line);
break; break;
case EOpMatrixTimesMatrix: case EOpMatrixTimesMatrix:
{ {
// TODO(jmadll): This code should check for overflows.
ASSERT(getType().getBasicType() == EbtFloat && rightNode->getBasicType() == EbtFloat); ASSERT(getType().getBasicType() == EbtFloat && rightNode->getBasicType() == EbtFloat);
const int leftCols = getCols(); const int leftCols = getCols();
...@@ -1244,6 +1247,7 @@ TConstantUnion *TIntermConstantUnion::foldBinary(TOperator op, ...@@ -1244,6 +1247,7 @@ TConstantUnion *TIntermConstantUnion::foldBinary(TOperator op,
case EOpMatrixTimesVector: case EOpMatrixTimesVector:
{ {
// TODO(jmadll): This code should check for overflows.
ASSERT(rightNode->getBasicType() == EbtFloat); ASSERT(rightNode->getBasicType() == EbtFloat);
const int matrixCols = getCols(); const int matrixCols = getCols();
...@@ -1266,6 +1270,7 @@ TConstantUnion *TIntermConstantUnion::foldBinary(TOperator op, ...@@ -1266,6 +1270,7 @@ TConstantUnion *TIntermConstantUnion::foldBinary(TOperator op,
case EOpVectorTimesMatrix: case EOpVectorTimesMatrix:
{ {
// TODO(jmadll): This code should check for overflows.
ASSERT(getType().getBasicType() == EbtFloat); ASSERT(getType().getBasicType() == EbtFloat);
const int matrixCols = rightNode->getType().getCols(); const int matrixCols = rightNode->getType().getCols();
......
...@@ -360,7 +360,8 @@ class TIntermConstantUnion : public TIntermTyped ...@@ -360,7 +360,8 @@ class TIntermConstantUnion : public TIntermTyped
TConstantUnion *foldBinary(TOperator op, TConstantUnion *foldBinary(TOperator op,
TIntermConstantUnion *rightNode, TIntermConstantUnion *rightNode,
TDiagnostics *diagnostics); TDiagnostics *diagnostics,
const TSourceLoc &line);
const TConstantUnion *foldIndexing(int index); const TConstantUnion *foldIndexing(int index);
TConstantUnion *foldUnaryNonComponentWise(TOperator op); TConstantUnion *foldUnaryNonComponentWise(TOperator op);
TConstantUnion *foldUnaryComponentWise(TOperator op, TDiagnostics *diagnostics); TConstantUnion *foldUnaryComponentWise(TOperator op, TDiagnostics *diagnostics);
......
...@@ -2270,6 +2270,71 @@ TEST_P(GLSLTest, NestedSequenceOperatorWithTernaryInside) ...@@ -2270,6 +2270,71 @@ TEST_P(GLSLTest, NestedSequenceOperatorWithTernaryInside)
EXPECT_PIXEL_COLOR_EQ(0, 0, GLColor::green); EXPECT_PIXEL_COLOR_EQ(0, 0, GLColor::green);
} }
// Test that multiplication ops are properly validated.
TEST_P(GLSLTest, FoldedIntProductOutOfBounds)
{
const std::string &fragmentShader =
"precision mediump float;\n"
"void main(void)\n"
"{\n"
" int prod = -2580 * 25800 * 25800;\n"
" gl_FragColor = vec4(float(prod));\n"
"}\n";
GLuint program = CompileProgram(mSimpleVSSource, fragmentShader);
EXPECT_EQ(0u, program);
glDeleteProgram(program);
}
// Test that multiplication ops are properly validated.
TEST_P(GLSLTest_ES3, FoldedUIntProductOutOfBounds)
{
const std::string &fragmentShader =
"#version 300 es\n"
"precision mediump float;\n"
"void main()\n"
"{\n"
" unsigned int prod = 2580u * 25800u * 25800u;\n"
" gl_FragColor = vec4(float(prod));\n"
"}\n";
GLuint program = CompileProgram(mSimpleVSSource, fragmentShader);
EXPECT_EQ(0u, program);
glDeleteProgram(program);
}
// Test that addition ops are properly validated.
TEST_P(GLSLTest, FoldedIntSumOutOfBounds)
{
const std::string &fragmentShader =
"precision mediump float;\n"
"void main(void)\n"
"{\n"
" int sum = 2147483647 + 2147483647;\n"
" gl_FragColor = vec4(float(sum));\n"
"}\n";
GLuint program = CompileProgram(mSimpleVSSource, fragmentShader);
EXPECT_EQ(0u, program);
glDeleteProgram(program);
}
// Test that subtraction ops are properly validated.
TEST_P(GLSLTest, FoldedIntDifferenceOutOfBounds)
{
const std::string &fragmentShader =
"precision mediump float;\n"
"void main(void)\n"
"{\n"
" int diff = -2147483000 - 2147483000;\n"
" gl_FragColor = vec4(float(diff));\n"
"}\n";
GLuint program = CompileProgram(mSimpleVSSource, fragmentShader);
EXPECT_EQ(0u, program);
glDeleteProgram(program);
}
} // anonymous namespace } // anonymous namespace
// Use this to select which configurations (e.g. which renderer, which GLES major version) these tests should be run against. // Use this to select which configurations (e.g. which renderer, which GLES major version) these tests should be run against.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment