Commit 5db69f57 by Jamie Madill Committed by Commit Bot

Add robust math to constant folding.

Previously our multiplication and other operators could do overflows, which can lead to security bugs. BUG=chromium:637050 Change-Id: Icee22a87909e205b71bda1c5bc1627fcf5e26e90 Reviewed-on: https://chromium-review.googlesource.com/382678 Commit-Queue: Jamie Madill <jmadill@chromium.org> Reviewed-by: 's avatarCorentin Wallez <cwallez@chromium.org>
parent a5615c69
......@@ -57,7 +57,7 @@ void InitializeDebugAnnotations(DebugAnnotator *debugAnnotator);
void UninitializeDebugAnnotations();
bool DebugAnnotationsActive();
}
} // namespace gl
#if defined(ANGLE_ENABLE_DEBUG_TRACE) || defined(ANGLE_ENABLE_DEBUG_ANNOTATIONS)
#define ANGLE_TRACE_ENABLED
......@@ -153,10 +153,9 @@ bool DebugAnnotationsActive();
// A macro for code which is not expected to be reached under valid assumptions
#if !defined(NDEBUG)
#define UNREACHABLE() { \
ERR("\t! Unreachable reached: %s(%d)\n", __FUNCTION__, __LINE__); \
assert(false); \
} ANGLE_EMPTY_STATEMENT
#define UNREACHABLE() \
ERR("\t! Unreachable reached: %s(%d)\n", __FUNCTION__, __LINE__), \
assert(false)
#else
#define UNREACHABLE() ERR("\t! Unreachable reached: %s(%d)\n", __FUNCTION__, __LINE__)
#endif
......
......@@ -16,7 +16,7 @@
// Unfortunately ANGLE relies on ASSERT being an empty statement, which these libs don't respect.
#ifndef NOTREACHED
#define NOTREACHED() 0
#define NOTREACHED() UNREACHABLE()
#endif
#endif // BASE_LOGGING_H_
\ No newline at end of file
#endif // BASE_LOGGING_H_
......@@ -7,8 +7,62 @@
#include "compiler/translator/ConstantUnion.h"
#include "base/numerics/safe_math.h"
#include "compiler/translator/Diagnostics.h"
namespace
{
template <typename T>
T CheckedSum(base::CheckedNumeric<T> lhs,
base::CheckedNumeric<T> rhs,
TDiagnostics *diag,
const TSourceLoc &line)
{
ASSERT(lhs.IsValid() && rhs.IsValid());
auto result = lhs + rhs;
if (!result.IsValid())
{
diag->error(line, "Addition out of range", "*", "");
return 0;
}
return result.ValueOrDefault(0);
}
template <typename T>
T CheckedDiff(base::CheckedNumeric<T> lhs,
base::CheckedNumeric<T> rhs,
TDiagnostics *diag,
const TSourceLoc &line)
{
ASSERT(lhs.IsValid() && rhs.IsValid());
auto result = lhs - rhs;
if (!result.IsValid())
{
diag->error(line, "Difference out of range", "*", "");
return 0;
}
return result.ValueOrDefault(0);
}
template <typename T>
T CheckedMul(base::CheckedNumeric<T> lhs,
base::CheckedNumeric<T> rhs,
TDiagnostics *diag,
const TSourceLoc &line)
{
ASSERT(lhs.IsValid() && rhs.IsValid());
auto result = lhs * rhs;
if (!result.IsValid())
{
diag->error(line, "Multiplication out of range", "*", "");
return 0;
}
return result.ValueOrDefault(0);
}
} // anonymous namespace
TConstantUnion::TConstantUnion()
{
iConst = 0;
......@@ -221,20 +275,21 @@ bool TConstantUnion::operator<(const TConstantUnion &constant) const
// static
TConstantUnion TConstantUnion::add(const TConstantUnion &lhs,
const TConstantUnion &rhs,
TDiagnostics *diag)
TDiagnostics *diag,
const TSourceLoc &line)
{
TConstantUnion returnValue;
ASSERT(lhs.type == rhs.type);
switch (lhs.type)
{
case EbtInt:
returnValue.setIConst(lhs.iConst + rhs.iConst);
returnValue.setIConst(CheckedSum<int>(lhs.iConst, rhs.iConst, diag, line));
break;
case EbtUInt:
returnValue.setUConst(lhs.uConst + rhs.uConst);
returnValue.setUConst(CheckedSum<unsigned int>(lhs.uConst, rhs.uConst, diag, line));
break;
case EbtFloat:
returnValue.setFConst(lhs.fConst + rhs.fConst);
returnValue.setFConst(CheckedSum<float>(lhs.fConst, rhs.fConst, diag, line));
break;
default:
UNREACHABLE();
......@@ -246,20 +301,21 @@ TConstantUnion TConstantUnion::add(const TConstantUnion &lhs,
// static
TConstantUnion TConstantUnion::sub(const TConstantUnion &lhs,
const TConstantUnion &rhs,
TDiagnostics *diag)
TDiagnostics *diag,
const TSourceLoc &line)
{
TConstantUnion returnValue;
ASSERT(lhs.type == rhs.type);
switch (lhs.type)
{
case EbtInt:
returnValue.setIConst(lhs.iConst - rhs.iConst);
returnValue.setIConst(CheckedDiff<int>(lhs.iConst, rhs.iConst, diag, line));
break;
case EbtUInt:
returnValue.setUConst(lhs.uConst - rhs.uConst);
returnValue.setUConst(CheckedDiff<unsigned int>(lhs.uConst, rhs.uConst, diag, line));
break;
case EbtFloat:
returnValue.setFConst(lhs.fConst - rhs.fConst);
returnValue.setFConst(CheckedDiff<float>(lhs.fConst, rhs.fConst, diag, line));
break;
default:
UNREACHABLE();
......@@ -271,20 +327,21 @@ TConstantUnion TConstantUnion::sub(const TConstantUnion &lhs,
// static
TConstantUnion TConstantUnion::mul(const TConstantUnion &lhs,
const TConstantUnion &rhs,
TDiagnostics *diag)
TDiagnostics *diag,
const TSourceLoc &line)
{
TConstantUnion returnValue;
ASSERT(lhs.type == rhs.type);
switch (lhs.type)
{
case EbtInt:
returnValue.setIConst(lhs.iConst * rhs.iConst);
returnValue.setIConst(CheckedMul<int>(lhs.iConst, rhs.iConst, diag, line));
break;
case EbtUInt:
returnValue.setUConst(lhs.uConst * rhs.uConst);
returnValue.setUConst(CheckedMul<unsigned int>(lhs.uConst, rhs.uConst, diag, line));
break;
case EbtFloat:
returnValue.setFConst(lhs.fConst * rhs.fConst);
returnValue.setFConst(CheckedMul<float>(lhs.fConst, rhs.fConst, diag, line));
break;
default:
UNREACHABLE();
......
......@@ -46,13 +46,16 @@ class TConstantUnion
bool operator<(const TConstantUnion &constant) const;
static TConstantUnion add(const TConstantUnion &lhs,
const TConstantUnion &rhs,
TDiagnostics *diag);
TDiagnostics *diag,
const TSourceLoc &line);
static TConstantUnion sub(const TConstantUnion &lhs,
const TConstantUnion &rhs,
TDiagnostics *diag);
TDiagnostics *diag,
const TSourceLoc &line);
static TConstantUnion mul(const TConstantUnion &lhs,
const TConstantUnion &rhs,
TDiagnostics *diag);
TDiagnostics *diag,
const TSourceLoc &line);
TConstantUnion operator%(const TConstantUnion &constant) const;
TConstantUnion operator>>(const TConstantUnion &constant) const;
TConstantUnion operator<<(const TConstantUnion &constant) const;
......
......@@ -1026,7 +1026,8 @@ TIntermTyped *TIntermBinary::fold(TDiagnostics *diagnostics)
{
return nullptr;
}
TConstantUnion *constArray = leftConstant->foldBinary(mOp, rightConstant, diagnostics);
TConstantUnion *constArray =
leftConstant->foldBinary(mOp, rightConstant, diagnostics, mLeft->getLine());
// Nodes may be constant folded without being qualified as constant.
return CreateFoldedNode(constArray, this, mType.getQualifier());
......@@ -1097,7 +1098,8 @@ TIntermTyped *TIntermAggregate::fold(TDiagnostics *diagnostics)
//
TConstantUnion *TIntermConstantUnion::foldBinary(TOperator op,
TIntermConstantUnion *rightNode,
TDiagnostics *diagnostics)
TDiagnostics *diagnostics,
const TSourceLoc &line)
{
const TConstantUnion *leftArray = getUnionArrayPointer();
const TConstantUnion *rightArray = rightNode->getUnionArrayPointer();
......@@ -1125,12 +1127,12 @@ TConstantUnion *TIntermConstantUnion::foldBinary(TOperator op,
case EOpAdd:
resultArray = new TConstantUnion[objectSize];
for (size_t i = 0; i < objectSize; i++)
resultArray[i] = TConstantUnion::add(leftArray[i], rightArray[i], diagnostics);
resultArray[i] = TConstantUnion::add(leftArray[i], rightArray[i], diagnostics, line);
break;
case EOpSub:
resultArray = new TConstantUnion[objectSize];
for (size_t i = 0; i < objectSize; i++)
resultArray[i] = TConstantUnion::sub(leftArray[i], rightArray[i], diagnostics);
resultArray[i] = TConstantUnion::sub(leftArray[i], rightArray[i], diagnostics, line);
break;
case EOpMul:
......@@ -1138,11 +1140,12 @@ TConstantUnion *TIntermConstantUnion::foldBinary(TOperator op,
case EOpMatrixTimesScalar:
resultArray = new TConstantUnion[objectSize];
for (size_t i = 0; i < objectSize; i++)
resultArray[i] = TConstantUnion::mul(leftArray[i], rightArray[i], diagnostics);
resultArray[i] = TConstantUnion::mul(leftArray[i], rightArray[i], diagnostics, line);
break;
case EOpMatrixTimesMatrix:
{
// TODO(jmadll): This code should check for overflows.
ASSERT(getType().getBasicType() == EbtFloat && rightNode->getBasicType() == EbtFloat);
const int leftCols = getCols();
......@@ -1244,6 +1247,7 @@ TConstantUnion *TIntermConstantUnion::foldBinary(TOperator op,
case EOpMatrixTimesVector:
{
// TODO(jmadll): This code should check for overflows.
ASSERT(rightNode->getBasicType() == EbtFloat);
const int matrixCols = getCols();
......@@ -1266,6 +1270,7 @@ TConstantUnion *TIntermConstantUnion::foldBinary(TOperator op,
case EOpVectorTimesMatrix:
{
// TODO(jmadll): This code should check for overflows.
ASSERT(getType().getBasicType() == EbtFloat);
const int matrixCols = rightNode->getType().getCols();
......
......@@ -360,7 +360,8 @@ class TIntermConstantUnion : public TIntermTyped
TConstantUnion *foldBinary(TOperator op,
TIntermConstantUnion *rightNode,
TDiagnostics *diagnostics);
TDiagnostics *diagnostics,
const TSourceLoc &line);
const TConstantUnion *foldIndexing(int index);
TConstantUnion *foldUnaryNonComponentWise(TOperator op);
TConstantUnion *foldUnaryComponentWise(TOperator op, TDiagnostics *diagnostics);
......
......@@ -2270,6 +2270,71 @@ TEST_P(GLSLTest, NestedSequenceOperatorWithTernaryInside)
EXPECT_PIXEL_COLOR_EQ(0, 0, GLColor::green);
}
// Test that multiplication ops are properly validated.
TEST_P(GLSLTest, FoldedIntProductOutOfBounds)
{
const std::string &fragmentShader =
"precision mediump float;\n"
"void main(void)\n"
"{\n"
" int prod = -2580 * 25800 * 25800;\n"
" gl_FragColor = vec4(float(prod));\n"
"}\n";
GLuint program = CompileProgram(mSimpleVSSource, fragmentShader);
EXPECT_EQ(0u, program);
glDeleteProgram(program);
}
// Test that multiplication ops are properly validated.
TEST_P(GLSLTest_ES3, FoldedUIntProductOutOfBounds)
{
const std::string &fragmentShader =
"#version 300 es\n"
"precision mediump float;\n"
"void main()\n"
"{\n"
" unsigned int prod = 2580u * 25800u * 25800u;\n"
" gl_FragColor = vec4(float(prod));\n"
"}\n";
GLuint program = CompileProgram(mSimpleVSSource, fragmentShader);
EXPECT_EQ(0u, program);
glDeleteProgram(program);
}
// Test that addition ops are properly validated.
TEST_P(GLSLTest, FoldedIntSumOutOfBounds)
{
const std::string &fragmentShader =
"precision mediump float;\n"
"void main(void)\n"
"{\n"
" int sum = 2147483647 + 2147483647;\n"
" gl_FragColor = vec4(float(sum));\n"
"}\n";
GLuint program = CompileProgram(mSimpleVSSource, fragmentShader);
EXPECT_EQ(0u, program);
glDeleteProgram(program);
}
// Test that subtraction ops are properly validated.
TEST_P(GLSLTest, FoldedIntDifferenceOutOfBounds)
{
const std::string &fragmentShader =
"precision mediump float;\n"
"void main(void)\n"
"{\n"
" int diff = -2147483000 - 2147483000;\n"
" gl_FragColor = vec4(float(diff));\n"
"}\n";
GLuint program = CompileProgram(mSimpleVSSource, fragmentShader);
EXPECT_EQ(0u, program);
glDeleteProgram(program);
}
} // anonymous namespace
// Use this to select which configurations (e.g. which renderer, which GLES major version) these tests should be run against.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment