Add expression complexity and call stack depth limits.

git-svn-id: https://angleproject.googlecode.com/svn/branches/dx11proto@2254 736b8ea6-26fd-11df-bfd4-992fa37f6226
parent b0f1b486
......@@ -110,6 +110,10 @@
'LinkIncremental': '2',
},
},
'xcode_settings': {
'COPY_PHASE_STRIP': 'NO',
'GCC_OPTIMIZATION_LEVEL': '0',
},
}, # Debug
'Release': {
'inherit_from': ['Common'],
......
......@@ -162,7 +162,13 @@ typedef enum {
// vec234, or mat234 type. The ShArrayIndexClampingStrategy enum,
// specified in the ShBuiltInResources when constructing the
// compiler, selects the strategy for the clamping implementation.
SH_CLAMP_INDIRECT_ARRAY_BOUNDS = 0x1000
SH_CLAMP_INDIRECT_ARRAY_BOUNDS = 0x1000,
// This flag limits the complexity of an expression.
SH_LIMIT_EXPRESSION_COMPLEXITY = 0x2000,
// This flag limits the depth of the call stack.
SH_LIMIT_CALL_STACK_DEPTH = 0x4000,
} ShCompileOptions;
// Defines alternate strategies for implementing array index clamping.
......@@ -225,6 +231,12 @@ typedef struct
// Selects a strategy to use when implementing array index clamping.
// Default is SH_CLAMP_WITH_CLAMP_INTRINSIC.
ShArrayIndexClampingStrategy ArrayIndexClampingStrategy;
// The maximum complexity an expression can be.
int MaxExpressionComplexity;
// The maximum depth a call stack can be.
int MaxCallStackDepth;
} ShBuiltInResources;
//
......
......@@ -72,8 +72,8 @@
'compiler/ConstantUnion.h',
'compiler/debug.cpp',
'compiler/debug.h',
'compiler/DetectRecursion.cpp',
'compiler/DetectRecursion.h',
'compiler/DetectCallDepth.cpp',
'compiler/DetectCallDepth.h',
'compiler/Diagnostics.h',
'compiler/Diagnostics.cpp',
'compiler/DirectiveHandler.h',
......
......@@ -5,7 +5,7 @@
//
#include "compiler/BuiltInFunctionEmulator.h"
#include "compiler/DetectRecursion.h"
#include "compiler/DetectCallDepth.h"
#include "compiler/ForLoopUnroll.h"
#include "compiler/Initialize.h"
#include "compiler/InitializeParseContext.h"
......@@ -104,6 +104,9 @@ TShHandleBase::~TShHandleBase() {
TCompiler::TCompiler(ShShaderType type, ShShaderSpec spec)
: shaderType(type),
shaderSpec(spec),
maxUniformVectors(0),
maxExpressionComplexity(0),
maxCallStackDepth(0),
fragmentPrecisionHigh(false),
clampingStrategy(SH_CLAMP_WITH_CLAMP_INTRINSIC),
builtInFunctionEmulator(type)
......@@ -122,6 +125,8 @@ bool TCompiler::Init(const ShBuiltInResources& resources)
maxUniformVectors = (shaderType == SH_VERTEX_SHADER) ?
resources.MaxVertexUniformVectors :
resources.MaxFragmentUniformVectors;
maxExpressionComplexity = resources.MaxExpressionComplexity;
maxCallStackDepth = resources.MaxCallStackDepth;
TScopedPoolAllocator scopedAlloc(&allocator, false);
// Generate built-in symbol table.
......@@ -185,7 +190,7 @@ bool TCompiler::compile(const char* const shaderStrings[],
success = intermediate.postProcess(root);
if (success)
success = detectRecursion(root);
success = detectCallDepth(root, infoSink, (compileOptions & SH_LIMIT_CALL_STACK_DEPTH) != 0);
if (success && (compileOptions & SH_VALIDATE_LOOP_INDEXING))
success = validateLimitations(root);
......@@ -208,6 +213,10 @@ bool TCompiler::compile(const char* const shaderStrings[],
if (success && (compileOptions & SH_CLAMP_INDIRECT_ARRAY_BOUNDS))
arrayBoundsClamper.MarkIndirectArrayBoundsForClamping(root);
// Disallow expressions deemed too complex.
if (success && (compileOptions & SH_LIMIT_EXPRESSION_COMPLEXITY))
success = limitExpressionComplexity(root);
// Call mapLongVariableNames() before collectAttribsUniforms() so in
// collectAttribsUniforms() we already have the mapped symbol names and
// we could composite mapped and original variable names.
......@@ -268,24 +277,27 @@ void TCompiler::clearResults()
nameMap.clear();
}
bool TCompiler::detectRecursion(TIntermNode* root)
bool TCompiler::detectCallDepth(TIntermNode* root, TInfoSink& infoSink, bool limitCallStackDepth)
{
DetectRecursion detect;
DetectCallDepth detect(infoSink, limitCallStackDepth, maxCallStackDepth);
root->traverse(&detect);
switch (detect.detectRecursion()) {
case DetectRecursion::kErrorNone:
switch (detect.detectCallDepth()) {
case DetectCallDepth::kErrorNone:
return true;
case DetectRecursion::kErrorMissingMain:
case DetectCallDepth::kErrorMissingMain:
infoSink.info.prefix(EPrefixError);
infoSink.info << "Missing main()";
return false;
case DetectRecursion::kErrorRecursion:
case DetectCallDepth::kErrorRecursion:
infoSink.info.prefix(EPrefixError);
infoSink.info << "Function recursion detected";
return false;
case DetectCallDepth::kErrorMaxDepthExceeded:
infoSink.info.prefix(EPrefixError);
infoSink.info << "Function call stack too deep";
return false;
default:
UNREACHABLE();
return false;
}
}
......@@ -327,6 +339,28 @@ bool TCompiler::enforceTimingRestrictions(TIntermNode* root, bool outputGraph)
}
}
bool TCompiler::limitExpressionComplexity(TIntermNode* root)
{
TIntermTraverser traverser;
root->traverse(&traverser);
TDependencyGraph graph(root);
for (TFunctionCallVector::const_iterator iter = graph.beginUserDefinedFunctionCalls();
iter != graph.endUserDefinedFunctionCalls();
++iter)
{
TGraphFunctionCall* samplerSymbol = *iter;
TDependencyGraphTraverser graphTraverser;
samplerSymbol->traverse(&graphTraverser);
}
if (traverser.getMaxDepth() > maxExpressionComplexity) {
infoSink.info << "Expression too complex.";
return false;
}
return true;
}
bool TCompiler::enforceFragmentShaderTimingRestrictions(const TDependencyGraph& graph)
{
RestrictFragmentShaderTiming restrictor(infoSink.info);
......
......@@ -4,21 +4,24 @@
// found in the LICENSE file.
//
#include "compiler/DetectRecursion.h"
#include "compiler/DetectCallDepth.h"
#include "compiler/InfoSink.h"
DetectRecursion::FunctionNode::FunctionNode(const TString& fname)
const int DetectCallDepth::FunctionNode::kInfiniteCallDepth;
DetectCallDepth::FunctionNode::FunctionNode(const TString& fname)
: name(fname),
visit(PreVisit)
{
}
const TString& DetectRecursion::FunctionNode::getName() const
const TString& DetectCallDepth::FunctionNode::getName() const
{
return name;
}
void DetectRecursion::FunctionNode::addCallee(
DetectRecursion::FunctionNode* callee)
void DetectCallDepth::FunctionNode::addCallee(
DetectCallDepth::FunctionNode* callee)
{
for (size_t i = 0; i < callees.size(); ++i) {
if (callees[i] == callee)
......@@ -27,21 +30,31 @@ void DetectRecursion::FunctionNode::addCallee(
callees.push_back(callee);
}
bool DetectRecursion::FunctionNode::detectRecursion()
int DetectCallDepth::FunctionNode::detectCallDepth(DetectCallDepth* detectCallDepth, int depth)
{
ASSERT(visit == PreVisit);
ASSERT(detectCallDepth);
int maxDepth = depth;
visit = InVisit;
for (size_t i = 0; i < callees.size(); ++i) {
switch (callees[i]->visit) {
case InVisit:
// cycle detected, i.e., recursion detected.
return true;
return kInfiniteCallDepth;
case PostVisit:
break;
case PreVisit: {
bool recursion = callees[i]->detectRecursion();
if (recursion)
return true;
// Check before we recurse so we don't go too depth
if (detectCallDepth->checkExceedsMaxDepth(depth))
return depth;
int callDepth = callees[i]->detectCallDepth(detectCallDepth, depth + 1);
// Check after we recurse so we can exit immediately and provide info.
if (detectCallDepth->checkExceedsMaxDepth(callDepth)) {
detectCallDepth->getInfoSink().info << "<-" << callees[i]->getName();
return callDepth;
}
maxDepth = std::max(callDepth, maxDepth);
break;
}
default:
......@@ -50,21 +63,29 @@ bool DetectRecursion::FunctionNode::detectRecursion()
}
}
visit = PostVisit;
return false;
return maxDepth;
}
DetectRecursion::DetectRecursion()
: currentFunction(NULL)
void DetectCallDepth::FunctionNode::reset()
{
visit = PreVisit;
}
DetectRecursion::~DetectRecursion()
DetectCallDepth::DetectCallDepth(TInfoSink& infoSink, bool limitCallStackDepth, int maxCallStackDepth)
: TIntermTraverser(true, false, true, false),
currentFunction(NULL),
infoSink(infoSink),
maxDepth(limitCallStackDepth ? maxCallStackDepth : FunctionNode::kInfiniteCallDepth)
{
}
DetectCallDepth::~DetectCallDepth()
{
for (size_t i = 0; i < functions.size(); ++i)
delete functions[i];
}
bool DetectRecursion::visitAggregate(Visit visit, TIntermAggregate* node)
bool DetectCallDepth::visitAggregate(Visit visit, TIntermAggregate* node)
{
switch (node->getOp())
{
......@@ -81,19 +102,21 @@ bool DetectRecursion::visitAggregate(Visit visit, TIntermAggregate* node)
currentFunction = new FunctionNode(node->getName());
functions.push_back(currentFunction);
}
} else if (visit == PostVisit) {
currentFunction = NULL;
}
break;
}
case EOpFunctionCall: {
// Function call.
if (visit == PreVisit) {
ASSERT(currentFunction != NULL);
FunctionNode* func = findFunctionByName(node->getName());
if (func == NULL) {
func = new FunctionNode(node->getName());
functions.push_back(func);
}
currentFunction->addCallee(func);
if (currentFunction)
currentFunction->addCallee(func);
}
break;
}
......@@ -103,17 +126,56 @@ bool DetectRecursion::visitAggregate(Visit visit, TIntermAggregate* node)
return true;
}
DetectRecursion::ErrorCode DetectRecursion::detectRecursion()
bool DetectCallDepth::checkExceedsMaxDepth(int depth)
{
return depth >= maxDepth;
}
void DetectCallDepth::resetFunctionNodes()
{
for (size_t i = 0; i < functions.size(); ++i) {
functions[i]->reset();
}
}
DetectCallDepth::ErrorCode DetectCallDepth::detectCallDepthForFunction(FunctionNode* func)
{
FunctionNode* main = findFunctionByName("main(");
if (main == NULL)
return kErrorMissingMain;
if (main->detectRecursion())
currentFunction = NULL;
resetFunctionNodes();
int maxCallDepth = func->detectCallDepth(this, 1);
if (maxCallDepth == FunctionNode::kInfiniteCallDepth)
return kErrorRecursion;
if (maxCallDepth >= maxDepth)
return kErrorMaxDepthExceeded;
return kErrorNone;
}
DetectCallDepth::ErrorCode DetectCallDepth::detectCallDepth()
{
if (maxDepth != FunctionNode::kInfiniteCallDepth) {
// Check all functions because the driver may fail on them
// TODO: Before detectingRecursion, strip unused functions.
for (size_t i = 0; i < functions.size(); ++i) {
ErrorCode error = detectCallDepthForFunction(functions[i]);
if (error != kErrorNone)
return error;
}
} else {
FunctionNode* main = findFunctionByName("main(");
if (main == NULL)
return kErrorMissingMain;
return detectCallDepthForFunction(main);
}
return kErrorNone;
}
DetectRecursion::FunctionNode* DetectRecursion::findFunctionByName(
DetectCallDepth::FunctionNode* DetectCallDepth::findFunctionByName(
const TString& name)
{
for (size_t i = 0; i < functions.size(); ++i) {
......
......@@ -9,28 +9,36 @@
#include "GLSLANG/ShaderLang.h"
#include <limits.h>
#include "compiler/intermediate.h"
#include "compiler/VariableInfo.h"
class TInfoSink;
// Traverses intermediate tree to detect function recursion.
class DetectRecursion : public TIntermTraverser {
class DetectCallDepth : public TIntermTraverser {
public:
enum ErrorCode {
kErrorMissingMain,
kErrorRecursion,
kErrorMaxDepthExceeded,
kErrorNone
};
DetectRecursion();
~DetectRecursion();
DetectCallDepth(TInfoSink& infoSync, bool limitCallStackDepth, int maxCallStackDepth);
~DetectCallDepth();
virtual bool visitAggregate(Visit, TIntermAggregate*);
ErrorCode detectRecursion();
bool checkExceedsMaxDepth(int depth);
ErrorCode detectCallDepth();
private:
class FunctionNode {
public:
static const int kInfiniteCallDepth = INT_MAX;
FunctionNode(const TString& fname);
const TString& getName() const;
......@@ -38,8 +46,11 @@ private:
// If a function is already in the callee list, this becomes a no-op.
void addCallee(FunctionNode* callee);
// Return true if recursive function calls are detected.
bool detectRecursion();
// Returns kInifinityCallDepth if recursive function calls are detected.
int detectCallDepth(DetectCallDepth* detectCallDepth, int depth);
// Reset state.
void reset();
private:
// mangled function name is unique.
......@@ -51,10 +62,19 @@ private:
Visit visit;
};
ErrorCode detectCallDepthForFunction(FunctionNode* func);
FunctionNode* findFunctionByName(const TString& name);
void resetFunctionNodes();
TInfoSink& getInfoSink() { return infoSink; }
TVector<FunctionNode*> functions;
FunctionNode* currentFunction;
TInfoSink& infoSink;
int maxDepth;
DetectCallDepth(const DetectCallDepth&);
void operator=(const DetectCallDepth&);
};
#endif // COMPILER_DETECT_RECURSION_H_
......@@ -83,8 +83,8 @@ protected:
bool InitBuiltInSymbolTable(const ShBuiltInResources& resources);
// Clears the results from the previous compilation.
void clearResults();
// Return true if function recursion is detected.
bool detectRecursion(TIntermNode* root);
// Return true if function recursion is detected or call depth exceeded.
bool detectCallDepth(TIntermNode* root, TInfoSink& infoSink, bool limitCallStackDepth);
// Rewrites a shader's intermediate tree according to the CSS Shaders spec.
void rewriteCSSShader(TIntermNode* root);
// Returns true if the given shader does not exceed the minimum
......@@ -106,6 +106,8 @@ protected:
// Returns true if the shader does not use sampler dependent values to affect control
// flow or in operations whose time can depend on the input values.
bool enforceFragmentShaderTimingRestrictions(const TDependencyGraph& graph);
// Return true if the maximum expression complexity below the limit.
bool limitExpressionComplexity(TIntermNode* root);
// Get built-in extensions with default behavior.
const TExtensionBehavior& getExtensionBehavior() const;
// Get the resources set by InitBuiltInSymbolTable
......@@ -120,6 +122,8 @@ private:
ShShaderSpec shaderSpec;
int maxUniformVectors;
int maxExpressionComplexity;
int maxCallStackDepth;
ShBuiltInResources compileResources;
......
......@@ -18,6 +18,7 @@
#include "GLSLANG/ShaderLang.h"
#include <algorithm>
#include "compiler/Common.h"
#include "compiler/Types.h"
#include "compiler/ConstantUnion.h"
......@@ -546,7 +547,8 @@ public:
inVisit(inVisit),
postVisit(postVisit),
rightToLeft(rightToLeft),
depth(0) {}
depth(0),
maxDepth(0) {}
virtual ~TIntermTraverser() {};
virtual void visitSymbol(TIntermSymbol*) {}
......@@ -558,7 +560,8 @@ public:
virtual bool visitLoop(Visit visit, TIntermLoop*) {return true;}
virtual bool visitBranch(Visit visit, TIntermBranch*) {return true;}
void incrementDepth() {depth++;}
int getMaxDepth() const {return maxDepth;}
void incrementDepth() {depth++; maxDepth = std::max(maxDepth, depth); }
void decrementDepth() {depth--;}
// Return the original name if hash function pointer is NULL;
......@@ -572,6 +575,7 @@ public:
protected:
int depth;
int maxDepth;
};
#endif // __INTERMEDIATE_H
......@@ -141,7 +141,7 @@
<ClCompile Include="BuiltInFunctionEmulator.cpp" />
<ClCompile Include="Compiler.cpp" />
<ClCompile Include="debug.cpp" />
<ClCompile Include="DetectRecursion.cpp" />
<ClCompile Include="DetectCallDepth.cpp" />
<ClCompile Include="Diagnostics.cpp" />
<ClCompile Include="DirectiveHandler.cpp" />
<ClCompile Include="ForLoopUnroll.cpp" />
......@@ -231,7 +231,7 @@
<ClInclude Include="Common.h" />
<ClInclude Include="ConstantUnion.h" />
<ClInclude Include="debug.h" />
<ClInclude Include="DetectRecursion.h" />
<ClInclude Include="DetectCallDepth.h" />
<ClInclude Include="Diagnostics.h" />
<ClInclude Include="DirectiveHandler.h" />
<ClInclude Include="ForLoopUnroll.h" />
......
......@@ -38,7 +38,7 @@
<ClCompile Include="debug.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="DetectRecursion.cpp">
<ClCompile Include="DetectCallDepth.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="Diagnostics.cpp">
......@@ -154,7 +154,7 @@
<ClInclude Include="debug.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="DetectRecursion.h">
<ClInclude Include="DetectCallDepth.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="Diagnostics.h">
......
......@@ -67,7 +67,7 @@
'target_name': 'compiler_tests',
'type': 'executable',
'dependencies': [
'../src/build_angle.gyp:translator_common',
'../src/build_angle.gyp:translator_glsl',
'gtest',
'gmock',
],
......@@ -79,6 +79,7 @@
],
'sources': [
'../third_party/googlemock/src/gmock_main.cc',
'compiler_tests/ExpressionLimit_test.cpp',
'compiler_tests/VariablePacker_test.cpp',
],
},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment