Commit 4002e92a by Olli Etuaho Committed by Commit Bot

Guard traversers used during parsing against stack overflow

Traversers used during parsing can be vulnerable to stack overflow since the AST has not yet been validated for max depth. Make sure to check for traversal depth in traversers used during parsing. We set the maximum traversal depth in ValidateGlobalInitializer and ValidateSwitchStatementList to 256, which matches the default value for validating general AST complexity. The depth check is on regardless of compiler options. In case the traversers go over the maximum traversal depth, they fail validation. BUG=angleproject:2453 TEST=angle_unittests Change-Id: I89ba576e8ef69663ba35d7b9050a6da319f1757c Reviewed-on: https://chromium-review.googlesource.com/995795Reviewed-by: 's avatarCorentin Wallez <cwallez@chromium.org> Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
parent dd196e0b
...@@ -18,24 +18,10 @@ namespace ...@@ -18,24 +18,10 @@ namespace
class MaxDepthTraverser : public TIntermTraverser class MaxDepthTraverser : public TIntermTraverser
{ {
public: public:
MaxDepthTraverser(int depthLimit) : TIntermTraverser(true, true, false), mDepthLimit(depthLimit) MaxDepthTraverser(int depthLimit) : TIntermTraverser(true, false, false, nullptr)
{ {
setMaxAllowedDepth(depthLimit);
} }
bool visitBinary(Visit, TIntermBinary *) override { return depthCheck(); }
bool visitUnary(Visit, TIntermUnary *) override { return depthCheck(); }
bool visitTernary(Visit, TIntermTernary *) override { return depthCheck(); }
bool visitSwizzle(Visit, TIntermSwizzle *) override { return depthCheck(); }
bool visitIfElse(Visit, TIntermIfElse *) override { return depthCheck(); }
bool visitAggregate(Visit, TIntermAggregate *) override { return depthCheck(); }
bool visitBlock(Visit, TIntermBlock *) override { return depthCheck(); }
bool visitLoop(Visit, TIntermLoop *) override { return depthCheck(); }
bool visitBranch(Visit, TIntermBranch *) override { return depthCheck(); }
protected:
bool depthCheck() const { return mMaxDepth < mDepthLimit; }
int mDepthLimit;
}; };
} // anonymous namespace } // anonymous namespace
......
...@@ -846,7 +846,7 @@ bool TOutputGLSLBase::visitBlock(Visit visit, TIntermBlock *node) ...@@ -846,7 +846,7 @@ bool TOutputGLSLBase::visitBlock(Visit visit, TIntermBlock *node)
{ {
TInfoSinkBase &out = objSink(); TInfoSinkBase &out = objSink();
// Scope the blocks except when at the global scope. // Scope the blocks except when at the global scope.
if (mDepth > 0) if (getCurrentTraversalDepth() > 0)
{ {
out << "{\n"; out << "{\n";
} }
...@@ -863,7 +863,7 @@ bool TOutputGLSLBase::visitBlock(Visit visit, TIntermBlock *node) ...@@ -863,7 +863,7 @@ bool TOutputGLSLBase::visitBlock(Visit visit, TIntermBlock *node)
} }
// Scope the blocks except when at the global scope. // Scope the blocks except when at the global scope.
if (mDepth > 0) if (getCurrentTraversalDepth() > 0)
{ {
out << "}\n"; out << "}\n";
} }
......
...@@ -31,7 +31,10 @@ void OutputFunction(TInfoSinkBase &out, const char *str, const TFunction *func) ...@@ -31,7 +31,10 @@ void OutputFunction(TInfoSinkBase &out, const char *str, const TFunction *func)
class TOutputTraverser : public TIntermTraverser class TOutputTraverser : public TIntermTraverser
{ {
public: public:
TOutputTraverser(TInfoSinkBase &out) : TIntermTraverser(true, false, false), mOut(out) {} TOutputTraverser(TInfoSinkBase &out)
: TIntermTraverser(true, false, false), mOut(out), mIndentDepth(0)
{
}
protected: protected:
void visitSymbol(TIntermSymbol *) override; void visitSymbol(TIntermSymbol *) override;
...@@ -52,7 +55,10 @@ class TOutputTraverser : public TIntermTraverser ...@@ -52,7 +55,10 @@ class TOutputTraverser : public TIntermTraverser
bool visitLoop(Visit visit, TIntermLoop *) override; bool visitLoop(Visit visit, TIntermLoop *) override;
bool visitBranch(Visit visit, TIntermBranch *) override; bool visitBranch(Visit visit, TIntermBranch *) override;
int getCurrentIndentDepth() const { return mIndentDepth + getCurrentTraversalDepth(); }
TInfoSinkBase &mOut; TInfoSinkBase &mOut;
int mIndentDepth;
}; };
// //
...@@ -79,7 +85,7 @@ void OutputTreeText(TInfoSinkBase &out, TIntermNode *node, const int depth) ...@@ -79,7 +85,7 @@ void OutputTreeText(TInfoSinkBase &out, TIntermNode *node, const int depth)
void TOutputTraverser::visitSymbol(TIntermSymbol *node) void TOutputTraverser::visitSymbol(TIntermSymbol *node)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
if (node->variable().symbolType() == SymbolType::Empty) if (node->variable().symbolType() == SymbolType::Empty)
{ {
...@@ -96,7 +102,7 @@ void TOutputTraverser::visitSymbol(TIntermSymbol *node) ...@@ -96,7 +102,7 @@ void TOutputTraverser::visitSymbol(TIntermSymbol *node)
bool TOutputTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node) bool TOutputTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
mOut << "vector swizzle ("; mOut << "vector swizzle (";
node->writeOffsetsAsXYZW(&mOut); node->writeOffsetsAsXYZW(&mOut);
mOut << ")"; mOut << ")";
...@@ -108,7 +114,7 @@ bool TOutputTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node) ...@@ -108,7 +114,7 @@ bool TOutputTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node)
bool TOutputTraverser::visitBinary(Visit visit, TIntermBinary *node) bool TOutputTraverser::visitBinary(Visit visit, TIntermBinary *node)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
switch (node->getOp()) switch (node->getOp())
{ {
...@@ -270,7 +276,7 @@ bool TOutputTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -270,7 +276,7 @@ bool TOutputTraverser::visitBinary(Visit visit, TIntermBinary *node)
TIntermConstantUnion *intermConstantUnion = node->getRight()->getAsConstantUnion(); TIntermConstantUnion *intermConstantUnion = node->getRight()->getAsConstantUnion();
ASSERT(intermConstantUnion); ASSERT(intermConstantUnion);
OutputTreeText(mOut, intermConstantUnion, mDepth + 1); OutputTreeText(mOut, intermConstantUnion, getCurrentIndentDepth() + 1);
// The following code finds the field name from the constant union // The following code finds the field name from the constant union
const TConstantUnion *constantUnion = intermConstantUnion->getConstantValue(); const TConstantUnion *constantUnion = intermConstantUnion->getConstantValue();
...@@ -294,7 +300,7 @@ bool TOutputTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -294,7 +300,7 @@ bool TOutputTraverser::visitBinary(Visit visit, TIntermBinary *node)
bool TOutputTraverser::visitUnary(Visit visit, TIntermUnary *node) bool TOutputTraverser::visitUnary(Visit visit, TIntermUnary *node)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
switch (node->getOp()) switch (node->getOp())
{ {
...@@ -348,22 +354,21 @@ bool TOutputTraverser::visitUnary(Visit visit, TIntermUnary *node) ...@@ -348,22 +354,21 @@ bool TOutputTraverser::visitUnary(Visit visit, TIntermUnary *node)
bool TOutputTraverser::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) bool TOutputTraverser::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
mOut << "Function Definition:\n"; mOut << "Function Definition:\n";
mOut << "\n";
return true; return true;
} }
bool TOutputTraverser::visitInvariantDeclaration(Visit visit, TIntermInvariantDeclaration *node) bool TOutputTraverser::visitInvariantDeclaration(Visit visit, TIntermInvariantDeclaration *node)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
mOut << "Invariant Declaration:\n"; mOut << "Invariant Declaration:\n";
return true; return true;
} }
void TOutputTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node) void TOutputTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
OutputFunction(mOut, "Function Prototype", node->getFunction()); OutputFunction(mOut, "Function Prototype", node->getFunction());
mOut << " (" << node->getCompleteString() << ")"; mOut << " (" << node->getCompleteString() << ")";
mOut << "\n"; mOut << "\n";
...@@ -371,7 +376,7 @@ void TOutputTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node) ...@@ -371,7 +376,7 @@ void TOutputTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
for (size_t i = 0; i < paramCount; ++i) for (size_t i = 0; i < paramCount; ++i)
{ {
const TVariable *param = node->getFunction()->getParam(i); const TVariable *param = node->getFunction()->getParam(i);
OutputTreeText(mOut, node, mDepth + 1); OutputTreeText(mOut, node, getCurrentIndentDepth() + 1);
mOut << "parameter: " << param->name() << " (" << param->getType().getCompleteString() mOut << "parameter: " << param->name() << " (" << param->getType().getCompleteString()
<< ")"; << ")";
} }
...@@ -379,7 +384,7 @@ void TOutputTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node) ...@@ -379,7 +384,7 @@ void TOutputTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
bool TOutputTraverser::visitAggregate(Visit visit, TIntermAggregate *node) bool TOutputTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
if (node->getOp() == EOpNull) if (node->getOp() == EOpNull)
{ {
...@@ -451,7 +456,7 @@ bool TOutputTraverser::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -451,7 +456,7 @@ bool TOutputTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
bool TOutputTraverser::visitBlock(Visit visit, TIntermBlock *node) bool TOutputTraverser::visitBlock(Visit visit, TIntermBlock *node)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
mOut << "Code block\n"; mOut << "Code block\n";
return true; return true;
...@@ -459,7 +464,7 @@ bool TOutputTraverser::visitBlock(Visit visit, TIntermBlock *node) ...@@ -459,7 +464,7 @@ bool TOutputTraverser::visitBlock(Visit visit, TIntermBlock *node)
bool TOutputTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node) bool TOutputTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
mOut << "Declaration\n"; mOut << "Declaration\n";
return true; return true;
...@@ -467,18 +472,18 @@ bool TOutputTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node) ...@@ -467,18 +472,18 @@ bool TOutputTraverser::visitDeclaration(Visit visit, TIntermDeclaration *node)
bool TOutputTraverser::visitTernary(Visit visit, TIntermTernary *node) bool TOutputTraverser::visitTernary(Visit visit, TIntermTernary *node)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
mOut << "Ternary selection"; mOut << "Ternary selection";
mOut << " (" << node->getCompleteString() << ")\n"; mOut << " (" << node->getCompleteString() << ")\n";
++mDepth; ++mIndentDepth;
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
mOut << "Condition\n"; mOut << "Condition\n";
node->getCondition()->traverse(this); node->getCondition()->traverse(this);
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
if (node->getTrueExpression()) if (node->getTrueExpression())
{ {
mOut << "true case\n"; mOut << "true case\n";
...@@ -486,29 +491,29 @@ bool TOutputTraverser::visitTernary(Visit visit, TIntermTernary *node) ...@@ -486,29 +491,29 @@ bool TOutputTraverser::visitTernary(Visit visit, TIntermTernary *node)
} }
if (node->getFalseExpression()) if (node->getFalseExpression())
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
mOut << "false case\n"; mOut << "false case\n";
node->getFalseExpression()->traverse(this); node->getFalseExpression()->traverse(this);
} }
--mDepth; --mIndentDepth;
return false; return false;
} }
bool TOutputTraverser::visitIfElse(Visit visit, TIntermIfElse *node) bool TOutputTraverser::visitIfElse(Visit visit, TIntermIfElse *node)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
mOut << "If test\n"; mOut << "If test\n";
++mDepth; ++mIndentDepth;
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
mOut << "Condition\n"; mOut << "Condition\n";
node->getCondition()->traverse(this); node->getCondition()->traverse(this);
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
if (node->getTrueBlock()) if (node->getTrueBlock())
{ {
mOut << "true case\n"; mOut << "true case\n";
...@@ -521,19 +526,19 @@ bool TOutputTraverser::visitIfElse(Visit visit, TIntermIfElse *node) ...@@ -521,19 +526,19 @@ bool TOutputTraverser::visitIfElse(Visit visit, TIntermIfElse *node)
if (node->getFalseBlock()) if (node->getFalseBlock())
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
mOut << "false case\n"; mOut << "false case\n";
node->getFalseBlock()->traverse(this); node->getFalseBlock()->traverse(this);
} }
--mDepth; --mIndentDepth;
return false; return false;
} }
bool TOutputTraverser::visitSwitch(Visit visit, TIntermSwitch *node) bool TOutputTraverser::visitSwitch(Visit visit, TIntermSwitch *node)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
mOut << "Switch\n"; mOut << "Switch\n";
...@@ -542,7 +547,7 @@ bool TOutputTraverser::visitSwitch(Visit visit, TIntermSwitch *node) ...@@ -542,7 +547,7 @@ bool TOutputTraverser::visitSwitch(Visit visit, TIntermSwitch *node)
bool TOutputTraverser::visitCase(Visit visit, TIntermCase *node) bool TOutputTraverser::visitCase(Visit visit, TIntermCase *node)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
if (node->getCondition() == nullptr) if (node->getCondition() == nullptr)
{ {
...@@ -562,7 +567,7 @@ void TOutputTraverser::visitConstantUnion(TIntermConstantUnion *node) ...@@ -562,7 +567,7 @@ void TOutputTraverser::visitConstantUnion(TIntermConstantUnion *node)
for (size_t i = 0; i < size; i++) for (size_t i = 0; i < size; i++)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
switch (node->getConstantValue()[i].getType()) switch (node->getConstantValue()[i].getType())
{ {
case EbtBool: case EbtBool:
...@@ -603,16 +608,16 @@ void TOutputTraverser::visitConstantUnion(TIntermConstantUnion *node) ...@@ -603,16 +608,16 @@ void TOutputTraverser::visitConstantUnion(TIntermConstantUnion *node)
bool TOutputTraverser::visitLoop(Visit visit, TIntermLoop *node) bool TOutputTraverser::visitLoop(Visit visit, TIntermLoop *node)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
mOut << "Loop with condition "; mOut << "Loop with condition ";
if (node->getType() == ELoopDoWhile) if (node->getType() == ELoopDoWhile)
mOut << "not "; mOut << "not ";
mOut << "tested first\n"; mOut << "tested first\n";
++mDepth; ++mIndentDepth;
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
if (node->getCondition()) if (node->getCondition())
{ {
mOut << "Loop Condition\n"; mOut << "Loop Condition\n";
...@@ -623,7 +628,7 @@ bool TOutputTraverser::visitLoop(Visit visit, TIntermLoop *node) ...@@ -623,7 +628,7 @@ bool TOutputTraverser::visitLoop(Visit visit, TIntermLoop *node)
mOut << "No loop condition\n"; mOut << "No loop condition\n";
} }
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
if (node->getBody()) if (node->getBody())
{ {
mOut << "Loop Body\n"; mOut << "Loop Body\n";
...@@ -636,19 +641,19 @@ bool TOutputTraverser::visitLoop(Visit visit, TIntermLoop *node) ...@@ -636,19 +641,19 @@ bool TOutputTraverser::visitLoop(Visit visit, TIntermLoop *node)
if (node->getExpression()) if (node->getExpression())
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
mOut << "Loop Terminal Expression\n"; mOut << "Loop Terminal Expression\n";
node->getExpression()->traverse(this); node->getExpression()->traverse(this);
} }
--mDepth; --mIndentDepth;
return false; return false;
} }
bool TOutputTraverser::visitBranch(Visit visit, TIntermBranch *node) bool TOutputTraverser::visitBranch(Visit visit, TIntermBranch *node)
{ {
OutputTreeText(mOut, node, mDepth); OutputTreeText(mOut, node, getCurrentIndentDepth());
switch (node->getFlowOp()) switch (node->getFlowOp())
{ {
...@@ -672,9 +677,9 @@ bool TOutputTraverser::visitBranch(Visit visit, TIntermBranch *node) ...@@ -672,9 +677,9 @@ bool TOutputTraverser::visitBranch(Visit visit, TIntermBranch *node)
if (node->getExpression()) if (node->getExpression())
{ {
mOut << " with expression\n"; mOut << " with expression\n";
++mDepth; ++mIndentDepth;
node->getExpression()->traverse(this); node->getExpression()->traverse(this);
--mDepth; --mIndentDepth;
} }
else else
{ {
......
...@@ -14,6 +14,8 @@ namespace sh ...@@ -14,6 +14,8 @@ namespace sh
namespace namespace
{ {
const int kMaxAllowedTraversalDepth = 256;
class ValidateGlobalInitializerTraverser : public TIntermTraverser class ValidateGlobalInitializerTraverser : public TIntermTraverser
{ {
public: public:
...@@ -25,7 +27,7 @@ class ValidateGlobalInitializerTraverser : public TIntermTraverser ...@@ -25,7 +27,7 @@ class ValidateGlobalInitializerTraverser : public TIntermTraverser
bool visitBinary(Visit visit, TIntermBinary *node) override; bool visitBinary(Visit visit, TIntermBinary *node) override;
bool visitUnary(Visit visit, TIntermUnary *node) override; bool visitUnary(Visit visit, TIntermUnary *node) override;
bool isValid() const { return mIsValid; } bool isValid() const { return mIsValid && mMaxDepth < mMaxAllowedDepth; }
bool issueWarning() const { return mIssueWarning; } bool issueWarning() const { return mIssueWarning; }
private: private:
...@@ -117,11 +119,12 @@ bool ValidateGlobalInitializerTraverser::visitUnary(Visit visit, TIntermUnary *n ...@@ -117,11 +119,12 @@ bool ValidateGlobalInitializerTraverser::visitUnary(Visit visit, TIntermUnary *n
} }
ValidateGlobalInitializerTraverser::ValidateGlobalInitializerTraverser(int shaderVersion) ValidateGlobalInitializerTraverser::ValidateGlobalInitializerTraverser(int shaderVersion)
: TIntermTraverser(true, false, false), : TIntermTraverser(true, false, false, nullptr),
mShaderVersion(shaderVersion), mShaderVersion(shaderVersion),
mIsValid(true), mIsValid(true),
mIssueWarning(false) mIssueWarning(false)
{ {
setMaxAllowedDepth(kMaxAllowedTraversalDepth);
} }
} // namespace } // namespace
......
...@@ -15,6 +15,8 @@ namespace sh ...@@ -15,6 +15,8 @@ namespace sh
namespace namespace
{ {
const int kMaxAllowedTraversalDepth = 256;
class ValidateSwitch : public TIntermTraverser class ValidateSwitch : public TIntermTraverser
{ {
public: public:
...@@ -69,7 +71,7 @@ bool ValidateSwitch::validate(TBasicType switchType, ...@@ -69,7 +71,7 @@ bool ValidateSwitch::validate(TBasicType switchType,
} }
ValidateSwitch::ValidateSwitch(TBasicType switchType, TDiagnostics *diagnostics) ValidateSwitch::ValidateSwitch(TBasicType switchType, TDiagnostics *diagnostics)
: TIntermTraverser(true, false, true), : TIntermTraverser(true, false, true, nullptr),
mSwitchType(switchType), mSwitchType(switchType),
mDiagnostics(diagnostics), mDiagnostics(diagnostics),
mCaseTypeMismatch(false), mCaseTypeMismatch(false),
...@@ -81,6 +83,7 @@ ValidateSwitch::ValidateSwitch(TBasicType switchType, TDiagnostics *diagnostics) ...@@ -81,6 +83,7 @@ ValidateSwitch::ValidateSwitch(TBasicType switchType, TDiagnostics *diagnostics)
mDefaultCount(0), mDefaultCount(0),
mDuplicateCases(false) mDuplicateCases(false)
{ {
setMaxAllowedDepth(kMaxAllowedTraversalDepth);
} }
void ValidateSwitch::visitSymbol(TIntermSymbol *) void ValidateSwitch::visitSymbol(TIntermSymbol *)
...@@ -290,8 +293,13 @@ bool ValidateSwitch::validateInternal(const TSourceLoc &loc) ...@@ -290,8 +293,13 @@ bool ValidateSwitch::validateInternal(const TSourceLoc &loc)
loc, "no statement between the last label and the end of the switch statement", loc, "no statement between the last label and the end of the switch statement",
"switch"); "switch");
} }
if (getMaxDepth() >= kMaxAllowedTraversalDepth)
{
mDiagnostics->error(loc, "too complex expressions inside a switch statement", "switch");
}
return !mStatementBeforeCase && !mLastStatementWasCase && !mCaseInsideControlFlow && return !mStatementBeforeCase && !mLastStatementWasCase && !mCaseInsideControlFlow &&
!mCaseTypeMismatch && mDefaultCount <= 1 && !mDuplicateCases; !mCaseTypeMismatch && mDefaultCount <= 1 && !mDuplicateCases &&
getMaxDepth() < kMaxAllowedTraversalDepth;
} }
} // anonymous namespace } // anonymous namespace
......
...@@ -110,8 +110,8 @@ TIntermTraverser::TIntermTraverser(bool preVisit, ...@@ -110,8 +110,8 @@ TIntermTraverser::TIntermTraverser(bool preVisit,
: preVisit(preVisit), : preVisit(preVisit),
inVisit(inVisit), inVisit(inVisit),
postVisit(postVisit), postVisit(postVisit),
mDepth(-1),
mMaxDepth(0), mMaxDepth(0),
mMaxAllowedDepth(std::numeric_limits<int>::max()),
mInGlobalScope(true), mInGlobalScope(true),
mSymbolTable(symbolTable) mSymbolTable(symbolTable)
{ {
...@@ -121,6 +121,11 @@ TIntermTraverser::~TIntermTraverser() ...@@ -121,6 +121,11 @@ TIntermTraverser::~TIntermTraverser()
{ {
} }
void TIntermTraverser::setMaxAllowedDepth(int depth)
{
mMaxAllowedDepth = depth;
}
const TIntermBlock *TIntermTraverser::getParentBlock() const const TIntermBlock *TIntermTraverser::getParentBlock() const
{ {
if (!mParentBlockStack.empty()) if (!mParentBlockStack.empty())
...@@ -215,6 +220,8 @@ void TIntermTraverser::traverseConstantUnion(TIntermConstantUnion *node) ...@@ -215,6 +220,8 @@ void TIntermTraverser::traverseConstantUnion(TIntermConstantUnion *node)
void TIntermTraverser::traverseSwizzle(TIntermSwizzle *node) void TIntermTraverser::traverseSwizzle(TIntermSwizzle *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true; bool visit = true;
...@@ -236,6 +243,8 @@ void TIntermTraverser::traverseSwizzle(TIntermSwizzle *node) ...@@ -236,6 +243,8 @@ void TIntermTraverser::traverseSwizzle(TIntermSwizzle *node)
void TIntermTraverser::traverseBinary(TIntermBinary *node) void TIntermTraverser::traverseBinary(TIntermBinary *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true; bool visit = true;
...@@ -271,6 +280,8 @@ void TIntermTraverser::traverseBinary(TIntermBinary *node) ...@@ -271,6 +280,8 @@ void TIntermTraverser::traverseBinary(TIntermBinary *node)
void TLValueTrackingTraverser::traverseBinary(TIntermBinary *node) void TLValueTrackingTraverser::traverseBinary(TIntermBinary *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true; bool visit = true;
...@@ -335,6 +346,8 @@ void TLValueTrackingTraverser::traverseBinary(TIntermBinary *node) ...@@ -335,6 +346,8 @@ void TLValueTrackingTraverser::traverseBinary(TIntermBinary *node)
void TIntermTraverser::traverseUnary(TIntermUnary *node) void TIntermTraverser::traverseUnary(TIntermUnary *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true; bool visit = true;
...@@ -353,6 +366,8 @@ void TIntermTraverser::traverseUnary(TIntermUnary *node) ...@@ -353,6 +366,8 @@ void TIntermTraverser::traverseUnary(TIntermUnary *node)
void TLValueTrackingTraverser::traverseUnary(TIntermUnary *node) void TLValueTrackingTraverser::traverseUnary(TIntermUnary *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true; bool visit = true;
...@@ -387,6 +402,8 @@ void TLValueTrackingTraverser::traverseUnary(TIntermUnary *node) ...@@ -387,6 +402,8 @@ void TLValueTrackingTraverser::traverseUnary(TIntermUnary *node)
void TIntermTraverser::traverseFunctionDefinition(TIntermFunctionDefinition *node) void TIntermTraverser::traverseFunctionDefinition(TIntermFunctionDefinition *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true; bool visit = true;
...@@ -413,6 +430,9 @@ void TIntermTraverser::traverseFunctionDefinition(TIntermFunctionDefinition *nod ...@@ -413,6 +430,9 @@ void TIntermTraverser::traverseFunctionDefinition(TIntermFunctionDefinition *nod
void TIntermTraverser::traverseBlock(TIntermBlock *node) void TIntermTraverser::traverseBlock(TIntermBlock *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
pushParentBlock(node); pushParentBlock(node);
bool visit = true; bool visit = true;
...@@ -446,6 +466,8 @@ void TIntermTraverser::traverseBlock(TIntermBlock *node) ...@@ -446,6 +466,8 @@ void TIntermTraverser::traverseBlock(TIntermBlock *node)
void TIntermTraverser::traverseInvariantDeclaration(TIntermInvariantDeclaration *node) void TIntermTraverser::traverseInvariantDeclaration(TIntermInvariantDeclaration *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true; bool visit = true;
...@@ -468,6 +490,8 @@ void TIntermTraverser::traverseInvariantDeclaration(TIntermInvariantDeclaration ...@@ -468,6 +490,8 @@ void TIntermTraverser::traverseInvariantDeclaration(TIntermInvariantDeclaration
void TIntermTraverser::traverseDeclaration(TIntermDeclaration *node) void TIntermTraverser::traverseDeclaration(TIntermDeclaration *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true; bool visit = true;
...@@ -496,6 +520,7 @@ void TIntermTraverser::traverseDeclaration(TIntermDeclaration *node) ...@@ -496,6 +520,7 @@ void TIntermTraverser::traverseDeclaration(TIntermDeclaration *node)
void TIntermTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node) void TIntermTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
visitFunctionPrototype(node); visitFunctionPrototype(node);
} }
...@@ -503,6 +528,8 @@ void TIntermTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node) ...@@ -503,6 +528,8 @@ void TIntermTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node)
void TIntermTraverser::traverseAggregate(TIntermAggregate *node) void TIntermTraverser::traverseAggregate(TIntermAggregate *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true; bool visit = true;
...@@ -633,6 +660,8 @@ TLValueTrackingTraverser::TLValueTrackingTraverser(bool preVisit, ...@@ -633,6 +660,8 @@ TLValueTrackingTraverser::TLValueTrackingTraverser(bool preVisit,
void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node) void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true; bool visit = true;
...@@ -680,6 +709,8 @@ void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node) ...@@ -680,6 +709,8 @@ void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
void TIntermTraverser::traverseTernary(TIntermTernary *node) void TIntermTraverser::traverseTernary(TIntermTernary *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true; bool visit = true;
...@@ -703,6 +734,8 @@ void TIntermTraverser::traverseTernary(TIntermTernary *node) ...@@ -703,6 +734,8 @@ void TIntermTraverser::traverseTernary(TIntermTernary *node)
void TIntermTraverser::traverseIfElse(TIntermIfElse *node) void TIntermTraverser::traverseIfElse(TIntermIfElse *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true; bool visit = true;
...@@ -728,6 +761,8 @@ void TIntermTraverser::traverseIfElse(TIntermIfElse *node) ...@@ -728,6 +761,8 @@ void TIntermTraverser::traverseIfElse(TIntermIfElse *node)
void TIntermTraverser::traverseSwitch(TIntermSwitch *node) void TIntermTraverser::traverseSwitch(TIntermSwitch *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true; bool visit = true;
...@@ -753,6 +788,8 @@ void TIntermTraverser::traverseSwitch(TIntermSwitch *node) ...@@ -753,6 +788,8 @@ void TIntermTraverser::traverseSwitch(TIntermSwitch *node)
void TIntermTraverser::traverseCase(TIntermCase *node) void TIntermTraverser::traverseCase(TIntermCase *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true; bool visit = true;
...@@ -774,6 +811,8 @@ void TIntermTraverser::traverseCase(TIntermCase *node) ...@@ -774,6 +811,8 @@ void TIntermTraverser::traverseCase(TIntermCase *node)
void TIntermTraverser::traverseLoop(TIntermLoop *node) void TIntermTraverser::traverseLoop(TIntermLoop *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true; bool visit = true;
...@@ -805,6 +844,8 @@ void TIntermTraverser::traverseLoop(TIntermLoop *node) ...@@ -805,6 +844,8 @@ void TIntermTraverser::traverseLoop(TIntermLoop *node)
void TIntermTraverser::traverseBranch(TIntermBranch *node) void TIntermTraverser::traverseBranch(TIntermBranch *node)
{ {
ScopedNodeInTraversalPath addToPath(this, node); ScopedNodeInTraversalPath addToPath(this, node);
if (!addToPath.isWithinDepthLimit())
return;
bool visit = true; bool visit = true;
......
...@@ -98,21 +98,24 @@ class TIntermTraverser : angle::NonCopyable ...@@ -98,21 +98,24 @@ class TIntermTraverser : angle::NonCopyable
void updateTree(); void updateTree();
protected: protected:
void setMaxAllowedDepth(int depth);
// Should only be called from traverse*() functions // Should only be called from traverse*() functions
void incrementDepth(TIntermNode *current) bool incrementDepth(TIntermNode *current)
{ {
mDepth++; mMaxDepth = std::max(mMaxDepth, static_cast<int>(mPath.size()));
mMaxDepth = std::max(mMaxDepth, mDepth);
mPath.push_back(current); mPath.push_back(current);
return mMaxDepth < mMaxAllowedDepth;
} }
// Should only be called from traverse*() functions // Should only be called from traverse*() functions
void decrementDepth() void decrementDepth()
{ {
mDepth--;
mPath.pop_back(); mPath.pop_back();
} }
int getCurrentTraversalDepth() const { return static_cast<int>(mPath.size()) - 1; }
// RAII helper for incrementDepth/decrementDepth // RAII helper for incrementDepth/decrementDepth
class ScopedNodeInTraversalPath class ScopedNodeInTraversalPath
{ {
...@@ -120,12 +123,15 @@ class TIntermTraverser : angle::NonCopyable ...@@ -120,12 +123,15 @@ class TIntermTraverser : angle::NonCopyable
ScopedNodeInTraversalPath(TIntermTraverser *traverser, TIntermNode *current) ScopedNodeInTraversalPath(TIntermTraverser *traverser, TIntermNode *current)
: mTraverser(traverser) : mTraverser(traverser)
{ {
mTraverser->incrementDepth(current); mWithinDepthLimit = mTraverser->incrementDepth(current);
} }
~ScopedNodeInTraversalPath() { mTraverser->decrementDepth(); } ~ScopedNodeInTraversalPath() { mTraverser->decrementDepth(); }
bool isWithinDepthLimit() { return mWithinDepthLimit; }
private: private:
TIntermTraverser *mTraverser; TIntermTraverser *mTraverser;
bool mWithinDepthLimit;
}; };
TIntermNode *getParentNode() { return mPath.size() <= 1 ? nullptr : mPath[mPath.size() - 2u]; } TIntermNode *getParentNode() { return mPath.size() <= 1 ? nullptr : mPath[mPath.size() - 2u]; }
...@@ -196,8 +202,8 @@ class TIntermTraverser : angle::NonCopyable ...@@ -196,8 +202,8 @@ class TIntermTraverser : angle::NonCopyable
const bool inVisit; const bool inVisit;
const bool postVisit; const bool postVisit;
int mDepth;
int mMaxDepth; int mMaxDepth;
int mMaxAllowedDepth;
bool mInGlobalScope; bool mInGlobalScope;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment