Commit b1762df4 by zmo@google.com

Detect function recursion and reject a shader if detected.

ANGLEBUG=191 TEST=shaders with function recursion are rejected. Review URL: http://codereview.appspot.com/4808061 git-svn-id: https://angleproject.googlecode.com/svn/trunk@711 736b8ea6-26fd-11df-bfd4-992fa37f6226
parent bb1d1713
...@@ -24,6 +24,8 @@ ...@@ -24,6 +24,8 @@
'compiler/ConstantUnion.h', 'compiler/ConstantUnion.h',
'compiler/debug.cpp', 'compiler/debug.cpp',
'compiler/debug.h', 'compiler/debug.h',
'compiler/DetectRecursion.cpp',
'compiler/DetectRecursion.h',
'compiler/glslang.h', 'compiler/glslang.h',
'compiler/glslang_lex.cpp', 'compiler/glslang_lex.cpp',
'compiler/glslang_tab.cpp', 'compiler/glslang_tab.cpp',
......
#define MAJOR_VERSION 0 #define MAJOR_VERSION 0
#define MINOR_VERSION 0 #define MINOR_VERSION 0
#define BUILD_VERSION 0 #define BUILD_VERSION 0
#define BUILD_REVISION 709 #define BUILD_REVISION 710
#define STRINGIFY(x) #x #define STRINGIFY(x) #x
#define MACRO_STRINGIFY(x) STRINGIFY(x) #define MACRO_STRINGIFY(x) STRINGIFY(x)
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
// found in the LICENSE file. // found in the LICENSE file.
// //
#include "compiler/DetectRecursion.h"
#include "compiler/Initialize.h" #include "compiler/Initialize.h"
#include "compiler/ParseHelper.h" #include "compiler/ParseHelper.h"
#include "compiler/ShHandle.h" #include "compiler/ShHandle.h"
...@@ -147,13 +148,16 @@ bool TCompiler::compile(const char* const shaderStrings[], ...@@ -147,13 +148,16 @@ bool TCompiler::compile(const char* const shaderStrings[],
TIntermNode* root = parseContext.treeRoot; TIntermNode* root = parseContext.treeRoot;
success = intermediate.postProcess(root); success = intermediate.postProcess(root);
if (success)
success = detectRecursion(root);
if (success && (compileOptions & SH_VALIDATE_LOOP_INDEXING)) if (success && (compileOptions & SH_VALIDATE_LOOP_INDEXING))
success = validateLimitations(root); success = validateLimitations(root);
// Call mapLongVariableNames() before collectAttribsUniforms() so in // Call mapLongVariableNames() before collectAttribsUniforms() so in
// collectAttribsUniforms() we already have the mapped symbol names and // collectAttribsUniforms() we already have the mapped symbol names and
// we could composite mapped and original variable names. // we could composite mapped and original variable names.
if (compileOptions & SH_MAP_LONG_VARIABLE_NAMES) if (success && (compileOptions & SH_MAP_LONG_VARIABLE_NAMES))
mapLongVariableNames(root); mapLongVariableNames(root);
if (success && (compileOptions & SH_ATTRIBUTES_UNIFORMS)) if (success && (compileOptions & SH_ATTRIBUTES_UNIFORMS))
...@@ -195,6 +199,25 @@ void TCompiler::clearResults() ...@@ -195,6 +199,25 @@ void TCompiler::clearResults()
uniforms.clear(); uniforms.clear();
} }
bool TCompiler::detectRecursion(TIntermNode* root)
{
DetectRecursion detect;
root->traverse(&detect);
switch (detect.detectRecursion()) {
case DetectRecursion::kErrorNone:
return true;
case DetectRecursion::kErrorMissingMain:
infoSink.info.message(EPrefixError, "Missing main()");
return false;
case DetectRecursion::kErrorRecursion:
infoSink.info.message(EPrefixError, "Function recursion detected");
return false;
default:
UNREACHABLE();
return false;
}
}
bool TCompiler::validateLimitations(TIntermNode* root) { bool TCompiler::validateLimitations(TIntermNode* root) {
ValidateLimitations validate(shaderType, infoSink.info); ValidateLimitations validate(shaderType, infoSink.info);
root->traverse(&validate); root->traverse(&validate);
......
//
// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
#include "compiler/DetectRecursion.h"
DetectRecursion::FunctionNode::FunctionNode(const TString& fname)
: name(fname),
visit(PreVisit)
{
}
const TString& DetectRecursion::FunctionNode::getName() const
{
return name;
}
void DetectRecursion::FunctionNode::addCallee(
DetectRecursion::FunctionNode* callee)
{
for (size_t i = 0; i < callees.size(); ++i) {
if (callees[i] == callee)
return;
}
callees.push_back(callee);
}
bool DetectRecursion::FunctionNode::detectRecursion()
{
ASSERT(visit == PreVisit);
visit = InVisit;
for (size_t i = 0; i < callees.size(); ++i) {
switch (callees[i]->visit) {
case InVisit:
// cycle detected, i.e., recursion detected.
return true;
case PostVisit:
break;
case PreVisit: {
bool recursion = callees[i]->detectRecursion();
if (recursion)
return true;
break;
}
default:
UNREACHABLE();
break;
}
}
visit = PostVisit;
return false;
}
DetectRecursion::DetectRecursion()
: currentFunction(NULL)
{
}
DetectRecursion::~DetectRecursion()
{
for (int i = 0; i < functions.size(); ++i)
delete functions[i];
}
void DetectRecursion::visitSymbol(TIntermSymbol*)
{
}
void DetectRecursion::visitConstantUnion(TIntermConstantUnion*)
{
}
bool DetectRecursion::visitBinary(Visit, TIntermBinary*)
{
return true;
}
bool DetectRecursion::visitUnary(Visit, TIntermUnary*)
{
return true;
}
bool DetectRecursion::visitSelection(Visit, TIntermSelection*)
{
return true;
}
bool DetectRecursion::visitAggregate(Visit visit, TIntermAggregate* node)
{
switch (node->getOp())
{
case EOpPrototype:
// Function declaration.
// Don't add FunctionNode here because node->getName() is the
// unmangled function name.
break;
case EOpFunction: {
// Function definition.
if (visit == PreVisit) {
currentFunction = findFunctionByName(node->getName());
if (currentFunction == NULL) {
currentFunction = new FunctionNode(node->getName());
functions.push_back(currentFunction);
}
}
break;
}
case EOpFunctionCall: {
// Function call.
if (visit == PreVisit) {
ASSERT(currentFunction != NULL);
FunctionNode* func = findFunctionByName(node->getName());
if (func == NULL) {
func = new FunctionNode(node->getName());
functions.push_back(func);
}
currentFunction->addCallee(func);
}
break;
}
default:
break;
}
return true;
}
bool DetectRecursion::visitLoop(Visit, TIntermLoop*)
{
return true;
}
bool DetectRecursion::visitBranch(Visit, TIntermBranch*)
{
return true;
}
DetectRecursion::ErrorCode DetectRecursion::detectRecursion()
{
FunctionNode* main = findFunctionByName("main(");
if (main == NULL)
return kErrorMissingMain;
if (main->detectRecursion())
return kErrorRecursion;
return kErrorNone;
}
DetectRecursion::FunctionNode* DetectRecursion::findFunctionByName(
const TString& name)
{
for (size_t i = 0; i < functions.size(); ++i) {
if (functions[i]->getName() == name)
return functions[i];
}
return NULL;
}
//
// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
#ifndef COMPILER_DETECT_RECURSION_H_
#define COMPILER_DETECT_RECURSION_H_
#include "GLSLANG/ShaderLang.h"
#include "compiler/intermediate.h"
#include "compiler/VariableInfo.h"
// Traverses intermediate tree to detect function recursion.
class DetectRecursion : public TIntermTraverser {
public:
enum ErrorCode {
kErrorMissingMain,
kErrorRecursion,
kErrorNone
};
DetectRecursion();
~DetectRecursion();
virtual void visitSymbol(TIntermSymbol*);
virtual void visitConstantUnion(TIntermConstantUnion*);
virtual bool visitBinary(Visit, TIntermBinary*);
virtual bool visitUnary(Visit, TIntermUnary*);
virtual bool visitSelection(Visit, TIntermSelection*);
virtual bool visitAggregate(Visit, TIntermAggregate*);
virtual bool visitLoop(Visit, TIntermLoop*);
virtual bool visitBranch(Visit, TIntermBranch*);
ErrorCode detectRecursion();
private:
class FunctionNode {
public:
FunctionNode(const TString& fname);
const TString& getName() const;
// If a function is already in the callee list, this becomes a no-op.
void addCallee(FunctionNode* callee);
// Return true if recursive function calls are detected.
bool detectRecursion();
private:
// mangled function name is unique.
TString name;
// functions that are directly called by this function.
TVector<FunctionNode*> callees;
Visit visit;
};
FunctionNode* findFunctionByName(const TString& name);
TVector<FunctionNode*> functions;
FunctionNode* currentFunction;
};
#endif // COMPILER_DETECT_RECURSION_H_
...@@ -66,6 +66,8 @@ protected: ...@@ -66,6 +66,8 @@ protected:
bool InitBuiltInSymbolTable(const ShBuiltInResources& resources); bool InitBuiltInSymbolTable(const ShBuiltInResources& resources);
// Clears the results from the previous compilation. // Clears the results from the previous compilation.
void clearResults(); void clearResults();
// Return true if function recursion is detected.
bool detectRecursion(TIntermNode* root);
// Returns true if the given shader does not exceed the minimum // Returns true if the given shader does not exceed the minimum
// functionality mandated in GLSL 1.0 spec Appendix A. // functionality mandated in GLSL 1.0 spec Appendix A.
bool validateLimitations(TIntermNode* root); bool validateLimitations(TIntermNode* root);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment