Commit fa63e947 by Zhenyao Mo

Improvement on loop unrolling with loops indexing sampler arrays

1) Before this workaround is hardwired on mac, now we move it behind a compil 2) Fix the issue where "break" inside the loop isn't handled while unrolled. BUG=338474 TEST=webgl conformance test sampler-array-using-loop-index.html Change-Id: I4996a42c2dea39a8a5af772c256f8e3cb383f59a Reviewed-on: https://chromium-review.googlesource.com/188079Reviewed-by: 's avatarZhenyao Mo <zmo@chromium.org> Tested-by: 's avatarZhenyao Mo <zmo@chromium.org>
parent b75fee4e
...@@ -142,9 +142,14 @@ typedef enum { ...@@ -142,9 +142,14 @@ typedef enum {
SH_SOURCE_PATH = 0x0020, SH_SOURCE_PATH = 0x0020,
SH_MAP_LONG_VARIABLE_NAMES = 0x0040, SH_MAP_LONG_VARIABLE_NAMES = 0x0040,
SH_UNROLL_FOR_LOOP_WITH_INTEGER_INDEX = 0x0080, SH_UNROLL_FOR_LOOP_WITH_INTEGER_INDEX = 0x0080,
// If a sampler array index happens to be a loop index,
// 1) if its type is integer, unroll the loop.
// 2) if its type is float, fail the shader compile.
// This is to work around a mac driver bug.
SH_UNROLL_FOR_LOOP_WITH_SAMPLER_ARRAY_INDEX = 0x0100,
// This is needed only as a workaround for certain OpenGL driver bugs. // This is needed only as a workaround for certain OpenGL driver bugs.
SH_EMULATE_BUILT_IN_FUNCTIONS = 0x0100, SH_EMULATE_BUILT_IN_FUNCTIONS = 0x0200,
// This is an experimental flag to enforce restrictions that aim to prevent // This is an experimental flag to enforce restrictions that aim to prevent
// timing attacks. // timing attacks.
...@@ -152,7 +157,7 @@ typedef enum { ...@@ -152,7 +157,7 @@ typedef enum {
// texture information via the timing channel. // texture information via the timing channel.
// To use this flag, you must compile the shader under the WebGL spec // To use this flag, you must compile the shader under the WebGL spec
// (using the SH_WEBGL_SPEC flag). // (using the SH_WEBGL_SPEC flag).
SH_TIMING_RESTRICTIONS = 0x0200, SH_TIMING_RESTRICTIONS = 0x0400,
// This flag prints the dependency graph that is used to enforce timing // This flag prints the dependency graph that is used to enforce timing
// restrictions on fragment shaders. // restrictions on fragment shaders.
...@@ -160,7 +165,7 @@ typedef enum { ...@@ -160,7 +165,7 @@ typedef enum {
// - The shader spec is SH_WEBGL_SPEC. // - The shader spec is SH_WEBGL_SPEC.
// - The compile options contain the SH_TIMING_RESTRICTIONS flag. // - The compile options contain the SH_TIMING_RESTRICTIONS flag.
// - The shader type is SH_FRAGMENT_SHADER. // - The shader type is SH_FRAGMENT_SHADER.
SH_DEPENDENCY_GRAPH = 0x0400, SH_DEPENDENCY_GRAPH = 0x0800,
// Enforce the GLSL 1.017 Appendix A section 7 packing restrictions. // Enforce the GLSL 1.017 Appendix A section 7 packing restrictions.
// This flag only enforces (and can only enforce) the packing // This flag only enforces (and can only enforce) the packing
...@@ -168,7 +173,7 @@ typedef enum { ...@@ -168,7 +173,7 @@ typedef enum {
// shaders. ShCheckVariablesWithinPackingLimits() lets embedders // shaders. ShCheckVariablesWithinPackingLimits() lets embedders
// enforce the packing restrictions for varying variables during // enforce the packing restrictions for varying variables during
// program link time. // program link time.
SH_ENFORCE_PACKING_RESTRICTIONS = 0x0800, SH_ENFORCE_PACKING_RESTRICTIONS = 0x1000,
// This flag ensures all indirect (expression-based) array indexing // This flag ensures all indirect (expression-based) array indexing
// is clamped to the bounds of the array. This ensures, for example, // is clamped to the bounds of the array. This ensures, for example,
...@@ -176,32 +181,32 @@ typedef enum { ...@@ -176,32 +181,32 @@ typedef enum {
// vec234, or mat234 type. The ShArrayIndexClampingStrategy enum, // vec234, or mat234 type. The ShArrayIndexClampingStrategy enum,
// specified in the ShBuiltInResources when constructing the // specified in the ShBuiltInResources when constructing the
// compiler, selects the strategy for the clamping implementation. // compiler, selects the strategy for the clamping implementation.
SH_CLAMP_INDIRECT_ARRAY_BOUNDS = 0x1000, SH_CLAMP_INDIRECT_ARRAY_BOUNDS = 0x2000,
// This flag limits the complexity of an expression. // This flag limits the complexity of an expression.
SH_LIMIT_EXPRESSION_COMPLEXITY = 0x2000, SH_LIMIT_EXPRESSION_COMPLEXITY = 0x4000,
// This flag limits the depth of the call stack. // This flag limits the depth of the call stack.
SH_LIMIT_CALL_STACK_DEPTH = 0x4000, SH_LIMIT_CALL_STACK_DEPTH = 0x8000,
// This flag initializes gl_Position to vec4(0,0,0,0) at the // This flag initializes gl_Position to vec4(0,0,0,0) at the
// beginning of the vertex shader's main(), and has no effect in the // beginning of the vertex shader's main(), and has no effect in the
// fragment shader. It is intended as a workaround for drivers which // fragment shader. It is intended as a workaround for drivers which
// incorrectly fail to link programs if gl_Position is not written. // incorrectly fail to link programs if gl_Position is not written.
SH_INIT_GL_POSITION = 0x8000, SH_INIT_GL_POSITION = 0x10000,
// This flag replaces // This flag replaces
// "a && b" with "a ? b : false", // "a && b" with "a ? b : false",
// "a || b" with "a ? true : b". // "a || b" with "a ? true : b".
// This is to work around a MacOSX driver bug that |b| is executed // This is to work around a MacOSX driver bug that |b| is executed
// independent of |a|'s value. // independent of |a|'s value.
SH_UNFOLD_SHORT_CIRCUIT = 0x10000, SH_UNFOLD_SHORT_CIRCUIT = 0x20000,
// This flag initializes varyings without static use in vertex shader // This flag initializes varyings without static use in vertex shader
// at the beginning of main(), and has no effects in the fragment shader. // at the beginning of main(), and has no effects in the fragment shader.
// It is intended as a workaround for drivers which incorrectly optimize // It is intended as a workaround for drivers which incorrectly optimize
// out such varyings and cause a link failure. // out such varyings and cause a link failure.
SH_INIT_VARYINGS_WITHOUT_STATIC_USE = 0x20000, SH_INIT_VARYINGS_WITHOUT_STATIC_USE = 0x40000,
} ShCompileOptions; } ShCompileOptions;
// Defines alternate strategies for implementing array index clamping. // Defines alternate strategies for implementing array index clamping.
......
...@@ -178,7 +178,21 @@ bool TCompiler::compile(const char* const shaderStrings[], ...@@ -178,7 +178,21 @@ bool TCompiler::compile(const char* const shaderStrings[],
// Unroll for-loop markup needs to happen after validateLimitations pass. // Unroll for-loop markup needs to happen after validateLimitations pass.
if (success && (compileOptions & SH_UNROLL_FOR_LOOP_WITH_INTEGER_INDEX)) if (success && (compileOptions & SH_UNROLL_FOR_LOOP_WITH_INTEGER_INDEX))
ForLoopUnroll::MarkForLoopsWithIntegerIndicesForUnrolling(root); {
ForLoopUnrollMarker marker(ForLoopUnrollMarker::kIntegerIndex);
root->traverse(&marker);
}
if (success && (compileOptions & SH_UNROLL_FOR_LOOP_WITH_SAMPLER_ARRAY_INDEX))
{
ForLoopUnrollMarker marker(ForLoopUnrollMarker::kSamplerArrayIndex);
root->traverse(&marker);
if (marker.samplerArrayIndexIsFloatLoopIndex())
{
infoSink.info.prefix(EPrefixError);
infoSink.info << "sampler array index is float loop index";
success = false;
}
}
// Built-in function emulation needs to happen after validateLimitations pass. // Built-in function emulation needs to happen after validateLimitations pass.
if (success && (compileOptions & SH_EMULATE_BUILT_IN_FUNCTIONS)) if (success && (compileOptions & SH_EMULATE_BUILT_IN_FUNCTIONS))
......
...@@ -6,210 +6,77 @@ ...@@ -6,210 +6,77 @@
#include "compiler/translator/ForLoopUnroll.h" #include "compiler/translator/ForLoopUnroll.h"
namespace { bool ForLoopUnrollMarker::visitBinary(Visit, TIntermBinary *node)
{
class IntegerForLoopUnrollMarker : public TIntermTraverser { if (mUnrollCondition != kSamplerArrayIndex)
public: return true;
virtual bool visitLoop(Visit, TIntermLoop* node) // If a sampler array index is also the loop index,
// 1) if the index type is integer, mark the loop for unrolling;
// 2) if the index type if float, set a flag to later fail compile.
switch (node->getOp())
{ {
// This is called after ValidateLimitations pass, so all the ASSERT case EOpIndexIndirect:
// should never fail. if (node->getLeft() != NULL && node->getRight() != NULL && node->getLeft()->getAsSymbolNode())
// See ValidateLimitations::validateForLoopInit(). {
ASSERT(node); TIntermSymbol *symbol = node->getLeft()->getAsSymbolNode();
ASSERT(node->getType() == ELoopFor); if (IsSampler(symbol->getBasicType()) && symbol->isArray() && !mLoopStack.empty())
ASSERT(node->getInit()); {
TIntermAggregate* decl = node->getInit()->getAsAggregate(); mVisitSamplerArrayIndexNodeInsideLoop = true;
ASSERT(decl && decl->getOp() == EOpDeclaration); node->getRight()->traverse(this);
TIntermSequence& declSeq = decl->getSequence(); mVisitSamplerArrayIndexNodeInsideLoop = false;
ASSERT(declSeq.size() == 1); // We have already visited all the children.
TIntermBinary* declInit = declSeq[0]->getAsBinaryNode(); return false;
ASSERT(declInit && declInit->getOp() == EOpInitialize);
ASSERT(declInit->getLeft());
TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode();
ASSERT(symbol);
TBasicType type = symbol->getBasicType();
ASSERT(type == EbtInt || type == EbtFloat);
if (type == EbtInt)
node->setUnrollFlag(true);
return true;
} }
}
}; break;
} // anonymous namepsace
void ForLoopUnroll::FillLoopIndexInfo(TIntermLoop* node, TLoopIndexInfo& info)
{
ASSERT(node->getType() == ELoopFor);
ASSERT(node->getUnrollFlag());
TIntermNode* init = node->getInit();
ASSERT(init != NULL);
TIntermAggregate* decl = init->getAsAggregate();
ASSERT((decl != NULL) && (decl->getOp() == EOpDeclaration));
TIntermSequence& declSeq = decl->getSequence();
ASSERT(declSeq.size() == 1);
TIntermBinary* declInit = declSeq[0]->getAsBinaryNode();
ASSERT((declInit != NULL) && (declInit->getOp() == EOpInitialize));
TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode();
ASSERT(symbol != NULL);
ASSERT(symbol->getBasicType() == EbtInt);
info.id = symbol->getId();
ASSERT(declInit->getRight() != NULL);
TIntermConstantUnion* initNode = declInit->getRight()->getAsConstantUnion();
ASSERT(initNode != NULL);
info.initValue = evaluateIntConstant(initNode);
info.currentValue = info.initValue;
TIntermNode* cond = node->getCondition();
ASSERT(cond != NULL);
TIntermBinary* binOp = cond->getAsBinaryNode();
ASSERT(binOp != NULL);
ASSERT(binOp->getRight() != NULL);
ASSERT(binOp->getRight()->getAsConstantUnion() != NULL);
info.incrementValue = getLoopIncrement(node);
info.stopValue = evaluateIntConstant(
binOp->getRight()->getAsConstantUnion());
info.op = binOp->getOp();
}
void ForLoopUnroll::Step()
{
ASSERT(mLoopIndexStack.size() > 0);
TLoopIndexInfo& info = mLoopIndexStack[mLoopIndexStack.size() - 1];
info.currentValue += info.incrementValue;
}
bool ForLoopUnroll::SatisfiesLoopCondition()
{
ASSERT(mLoopIndexStack.size() > 0);
TLoopIndexInfo& info = mLoopIndexStack[mLoopIndexStack.size() - 1];
// Relational operator is one of: > >= < <= == or !=.
switch (info.op) {
case EOpEqual:
return (info.currentValue == info.stopValue);
case EOpNotEqual:
return (info.currentValue != info.stopValue);
case EOpLessThan:
return (info.currentValue < info.stopValue);
case EOpGreaterThan:
return (info.currentValue > info.stopValue);
case EOpLessThanEqual:
return (info.currentValue <= info.stopValue);
case EOpGreaterThanEqual:
return (info.currentValue >= info.stopValue);
default: default:
UNREACHABLE(); break;
} }
return false; return true;
} }
bool ForLoopUnroll::NeedsToReplaceSymbolWithValue(TIntermSymbol* symbol) bool ForLoopUnrollMarker::visitLoop(Visit, TIntermLoop *node)
{ {
for (TVector<TLoopIndexInfo>::iterator i = mLoopIndexStack.begin(); if (mUnrollCondition == kIntegerIndex)
i != mLoopIndexStack.end(); {
++i) { // Check if loop index type is integer.
if (i->id == symbol->getId()) // This is called after ValidateLimitations pass, so all the calls
return true; // should be valid. See ValidateLimitations::validateForLoopInit().
TIntermSequence& declSeq = node->getInit()->getAsAggregate()->getSequence();
TIntermSymbol* symbol = declSeq[0]->getAsBinaryNode()->getLeft()->getAsSymbolNode();
if (symbol->getBasicType() == EbtInt)
node->setUnrollFlag(true);
} }
return false;
}
int ForLoopUnroll::GetLoopIndexValue(TIntermSymbol* symbol) TIntermNode *body = node->getBody();
{ if (body != NULL)
for (TVector<TLoopIndexInfo>::iterator i = mLoopIndexStack.begin(); {
i != mLoopIndexStack.end(); mLoopStack.push(node);
++i) { body->traverse(this);
if (i->id == symbol->getId()) mLoopStack.pop();
return i->currentValue;
} }
UNREACHABLE(); // The loop is fully processed - no need to visit children.
return false; return false;
} }
void ForLoopUnroll::Push(TLoopIndexInfo& info) void ForLoopUnrollMarker::visitSymbol(TIntermSymbol* symbol)
{
mLoopIndexStack.push_back(info);
}
void ForLoopUnroll::Pop()
{
mLoopIndexStack.pop_back();
}
// static
void ForLoopUnroll::MarkForLoopsWithIntegerIndicesForUnrolling(
TIntermNode* root)
{ {
ASSERT(root); if (!mVisitSamplerArrayIndexNodeInsideLoop)
return;
IntegerForLoopUnrollMarker marker; TIntermLoop *loop = mLoopStack.findLoop(symbol);
root->traverse(&marker); if (loop)
} {
switch (symbol->getBasicType())
int ForLoopUnroll::getLoopIncrement(TIntermLoop* node) {
{ case EbtFloat:
TIntermNode* expr = node->getExpression(); mSamplerArrayIndexIsFloatLoopIndex = true;
ASSERT(expr != NULL);
// for expression has one of the following forms:
// loop_index++
// loop_index--
// loop_index += constant_expression
// loop_index -= constant_expression
// ++loop_index
// --loop_index
// The last two forms are not specified in the spec, but I am assuming
// its an oversight.
TIntermUnary* unOp = expr->getAsUnaryNode();
TIntermBinary* binOp = unOp ? NULL : expr->getAsBinaryNode();
TOperator op = EOpNull;
TIntermConstantUnion* incrementNode = NULL;
if (unOp != NULL) {
op = unOp->getOp();
} else if (binOp != NULL) {
op = binOp->getOp();
ASSERT(binOp->getRight() != NULL);
incrementNode = binOp->getRight()->getAsConstantUnion();
ASSERT(incrementNode != NULL);
}
int increment = 0;
// The operator is one of: ++ -- += -=.
switch (op) {
case EOpPostIncrement:
case EOpPreIncrement:
ASSERT((unOp != NULL) && (binOp == NULL));
increment = 1;
break;
case EOpPostDecrement:
case EOpPreDecrement:
ASSERT((unOp != NULL) && (binOp == NULL));
increment = -1;
break;
case EOpAddAssign:
ASSERT((unOp == NULL) && (binOp != NULL));
increment = evaluateIntConstant(incrementNode);
break; break;
case EOpSubAssign: case EbtInt:
ASSERT((unOp == NULL) && (binOp != NULL)); loop->setUnrollFlag(true);
increment = - evaluateIntConstant(incrementNode);
break; break;
default: default:
ASSERT(false); UNREACHABLE();
}
} }
return increment;
}
int ForLoopUnroll::evaluateIntConstant(TIntermConstantUnion* node)
{
ASSERT((node != NULL) && (node->getUnionArrayPointer() != NULL));
return node->getIConst(0);
} }
...@@ -7,46 +7,44 @@ ...@@ -7,46 +7,44 @@
#ifndef COMPILER_FORLOOPUNROLL_H_ #ifndef COMPILER_FORLOOPUNROLL_H_
#define COMPILER_FORLOOPUNROLL_H_ #define COMPILER_FORLOOPUNROLL_H_
#include "compiler/translator/intermediate.h" #include "compiler/translator/LoopInfo.h"
struct TLoopIndexInfo { // This class detects for-loops that needs to be unrolled.
int id; // Currently we support two unroll conditions:
int initValue; // 1) kForLoopWithIntegerIndex: unroll if the index type is integer.
int stopValue; // 2) kForLoopWithSamplerArrayIndex: unroll where a sampler array index
int incrementValue; // is also the loop integer index, and reject and fail a compile
TOperator op; // where a sampler array index is also the loop float index.
int currentValue; class ForLoopUnrollMarker : public TIntermTraverser
}; {
public:
class ForLoopUnroll { enum UnrollCondition
public: {
ForLoopUnroll() { } kIntegerIndex,
kSamplerArrayIndex
void FillLoopIndexInfo(TIntermLoop* node, TLoopIndexInfo& info); };
// Update the info.currentValue for the next loop iteration. ForLoopUnrollMarker(UnrollCondition condition)
void Step(); : mUnrollCondition(condition),
mSamplerArrayIndexIsFloatLoopIndex(false),
// Return false if loop condition is no longer satisfied. mVisitSamplerArrayIndexNodeInsideLoop(false)
bool SatisfiesLoopCondition(); {
}
// Check if the symbol is the index of a loop that's unrolled.
bool NeedsToReplaceSymbolWithValue(TIntermSymbol* symbol); virtual bool visitBinary(Visit, TIntermBinary *node);
virtual bool visitLoop(Visit, TIntermLoop *node);
// Return the current value of a given loop index symbol. virtual void visitSymbol(TIntermSymbol *node);
int GetLoopIndexValue(TIntermSymbol* symbol);
bool samplerArrayIndexIsFloatLoopIndex() const
void Push(TLoopIndexInfo& info); {
void Pop(); return mSamplerArrayIndexIsFloatLoopIndex;
}
static void MarkForLoopsWithIntegerIndicesForUnrolling(TIntermNode* root);
private:
private: UnrollCondition mUnrollCondition;
int getLoopIncrement(TIntermLoop* node); TLoopStack mLoopStack;
bool mSamplerArrayIndexIsFloatLoopIndex;
int evaluateIntConstant(TIntermConstantUnion* node); bool mVisitSamplerArrayIndexNodeInsideLoop;
TVector<TLoopIndexInfo> mLoopIndexStack;
}; };
#endif #endif
//
// Copyright (c) 2002-2014 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
#include "compiler/translator/LoopInfo.h"
namespace
{
int EvaluateIntConstant(TIntermConstantUnion *node)
{
ASSERT(node && node->getUnionArrayPointer());
return node->getIConst(0);
}
int GetLoopIntIncrement(TIntermLoop *node)
{
TIntermNode *expr = node->getExpression();
// for expression has one of the following forms:
// loop_index++
// loop_index--
// loop_index += constant_expression
// loop_index -= constant_expression
// ++loop_index
// --loop_index
// The last two forms are not specified in the spec, but I am assuming
// its an oversight.
TIntermUnary *unOp = expr->getAsUnaryNode();
TIntermBinary *binOp = unOp ? NULL : expr->getAsBinaryNode();
TOperator op = EOpNull;
TIntermConstantUnion *incrementNode = NULL;
if (unOp)
{
op = unOp->getOp();
}
else if (binOp)
{
op = binOp->getOp();
ASSERT(binOp->getRight());
incrementNode = binOp->getRight()->getAsConstantUnion();
ASSERT(incrementNode);
}
int increment = 0;
// The operator is one of: ++ -- += -=.
switch (op)
{
case EOpPostIncrement:
case EOpPreIncrement:
ASSERT(unOp && !binOp);
increment = 1;
break;
case EOpPostDecrement:
case EOpPreDecrement:
ASSERT(unOp && !binOp);
increment = -1;
break;
case EOpAddAssign:
ASSERT(!unOp && binOp);
increment = EvaluateIntConstant(incrementNode);
break;
case EOpSubAssign:
ASSERT(!unOp && binOp);
increment = - EvaluateIntConstant(incrementNode);
break;
default:
UNREACHABLE();
}
return increment;
}
} // namespace anonymous
TLoopIndexInfo::TLoopIndexInfo()
: mId(-1),
mType(EbtVoid),
mInitValue(0),
mStopValue(0),
mIncrementValue(0),
mOp(EOpNull),
mCurrentValue(0)
{
}
void TLoopIndexInfo::fillInfo(TIntermLoop *node)
{
if (node == NULL)
return;
// Here we assume all the operations are valid, because the loop node is
// already validated in ValidateLimitations.
TIntermSequence &declSeq =
node->getInit()->getAsAggregate()->getSequence();
TIntermBinary *declInit = declSeq[0]->getAsBinaryNode();
TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
mId = symbol->getId();
mType = symbol->getBasicType();
if (mType == EbtInt)
{
TIntermConstantUnion* initNode = declInit->getRight()->getAsConstantUnion();
mInitValue = EvaluateIntConstant(initNode);
mCurrentValue = mInitValue;
mIncrementValue = GetLoopIntIncrement(node);
TIntermBinary* binOp = node->getCondition()->getAsBinaryNode();
mStopValue = EvaluateIntConstant(
binOp->getRight()->getAsConstantUnion());
mOp = binOp->getOp();
}
}
bool TLoopIndexInfo::satisfiesLoopCondition() const
{
// Relational operator is one of: > >= < <= == or !=.
switch (mOp)
{
case EOpEqual:
return (mCurrentValue == mStopValue);
case EOpNotEqual:
return (mCurrentValue != mStopValue);
case EOpLessThan:
return (mCurrentValue < mStopValue);
case EOpGreaterThan:
return (mCurrentValue > mStopValue);
case EOpLessThanEqual:
return (mCurrentValue <= mStopValue);
case EOpGreaterThanEqual:
return (mCurrentValue >= mStopValue);
default:
UNREACHABLE();
return false;
}
}
TLoopInfo::TLoopInfo()
: loop(NULL)
{
}
TLoopInfo::TLoopInfo(TIntermLoop *node)
: loop(node)
{
index.fillInfo(node);
}
TIntermLoop *TLoopStack::findLoop(TIntermSymbol *symbol)
{
if (!symbol)
return NULL;
for (iterator iter = begin(); iter != end(); ++iter)
{
if (iter->index.getId() == symbol->getId())
return iter->loop;
}
return NULL;
}
TLoopIndexInfo *TLoopStack::getIndexInfo(TIntermSymbol *symbol)
{
if (!symbol)
return NULL;
for (iterator iter = begin(); iter != end(); ++iter)
{
if (iter->index.getId() == symbol->getId())
return &(iter->index);
}
return NULL;
}
void TLoopStack::step()
{
ASSERT(!empty());
rbegin()->index.step();
}
bool TLoopStack::satisfiesLoopCondition()
{
ASSERT(!empty());
return rbegin()->index.satisfiesLoopCondition();
}
bool TLoopStack::needsToReplaceSymbolWithValue(TIntermSymbol *symbol)
{
TIntermLoop *loop = findLoop(symbol);
return loop && loop->getUnrollFlag();
}
int TLoopStack::getLoopIndexValue(TIntermSymbol *symbol)
{
TLoopIndexInfo *info = getIndexInfo(symbol);
ASSERT(info);
return info->getCurrentValue();
}
void TLoopStack::push(TIntermLoop *loop)
{
TLoopInfo info(loop);
push_back(info);
}
void TLoopStack::pop()
{
pop_back();
}
//
// Copyright (c) 2002-2014 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
#ifndef COMPILER_TRANSLATOR_LOOP_INFO_H_
#define COMPILER_TRANSLATOR_LOOP_INFO_H_
#include "compiler/translator/intermediate.h"
class TLoopIndexInfo
{
public:
TLoopIndexInfo();
// If type is EbtInt, fill all fields of the structure with info
// extracted from a loop node.
// If type is not EbtInt, only fill id and type.
void fillInfo(TIntermLoop *node);
int getId() const { return mId; }
void setId(int id) { mId = id; }
TBasicType getType() const { return mType; }
void setType(TBasicType type) { mType = type; }
int getCurrentValue() const { return mCurrentValue; }
void step() { mCurrentValue += mIncrementValue; }
// Check if the current value satisfies the loop condition.
bool satisfiesLoopCondition() const;
private:
int mId;
TBasicType mType; // Either EbtInt or EbtFloat
// Below fields are only valid if the index's type is int.
int mInitValue;
int mStopValue;
int mIncrementValue;
TOperator mOp;
int mCurrentValue;
};
struct TLoopInfo
{
TLoopIndexInfo index;
TIntermLoop *loop;
TLoopInfo();
TLoopInfo(TIntermLoop *node);
};
class TLoopStack : public TVector<TLoopInfo>
{
public:
// Search loop stack for a loop whose index matches the input symbol.
TIntermLoop *findLoop(TIntermSymbol *symbol);
// Find the loop index info in the loop stack by the input symbol.
TLoopIndexInfo *getIndexInfo(TIntermSymbol *symbol);
// Update the currentValue for the next loop iteration.
void step();
// Return false if loop condition is no longer satisfied.
bool satisfiesLoopCondition();
// Check if the symbol is the index of a loop that's unrolled.
bool needsToReplaceSymbolWithValue(TIntermSymbol *symbol);
// Return the current value of a given loop index symbol.
int getLoopIndexValue(TIntermSymbol *symbol);
void push(TIntermLoop *info);
void pop();
};
#endif // COMPILER_TRANSLATOR_LOOP_INDEX_H_
...@@ -159,8 +159,8 @@ const ConstantUnion* TOutputGLSLBase::writeConstantUnion(const TType& type, ...@@ -159,8 +159,8 @@ const ConstantUnion* TOutputGLSLBase::writeConstantUnion(const TType& type,
void TOutputGLSLBase::visitSymbol(TIntermSymbol* node) void TOutputGLSLBase::visitSymbol(TIntermSymbol* node)
{ {
TInfoSinkBase& out = objSink(); TInfoSinkBase& out = objSink();
if (mLoopUnroll.NeedsToReplaceSymbolWithValue(node)) if (mLoopUnrollStack.needsToReplaceSymbolWithValue(node))
out << mLoopUnroll.GetLoopIndexValue(node); out << mLoopUnrollStack.getLoopIndexValue(node);
else else
out << hashVariableName(node->getSymbol()); out << hashVariableName(node->getSymbol());
...@@ -643,7 +643,8 @@ bool TOutputGLSLBase::visitLoop(Visit visit, TIntermLoop* node) ...@@ -643,7 +643,8 @@ bool TOutputGLSLBase::visitLoop(Visit visit, TIntermLoop* node)
TLoopType loopType = node->getType(); TLoopType loopType = node->getType();
if (loopType == ELoopFor) // for loop if (loopType == ELoopFor) // for loop
{ {
if (!node->getUnrollFlag()) { if (!node->getUnrollFlag())
{
out << "for ("; out << "for (";
if (node->getInit()) if (node->getInit())
node->getInit()->traverse(this); node->getInit()->traverse(this);
...@@ -657,6 +658,18 @@ bool TOutputGLSLBase::visitLoop(Visit visit, TIntermLoop* node) ...@@ -657,6 +658,18 @@ bool TOutputGLSLBase::visitLoop(Visit visit, TIntermLoop* node)
node->getExpression()->traverse(this); node->getExpression()->traverse(this);
out << ")\n"; out << ")\n";
} }
else
{
// Need to put a one-iteration loop here to handle break.
TIntermSequence &declSeq =
node->getInit()->getAsAggregate()->getSequence();
TIntermSymbol *indexSymbol =
declSeq[0]->getAsBinaryNode()->getLeft()->getAsSymbolNode();
TString name = hashVariableName(indexSymbol->getSymbol());
out << "for (int " << name << " = 0; "
<< name << " < 1; "
<< "++" << name << ")\n";
}
} }
else if (loopType == ELoopWhile) // while loop else if (loopType == ELoopWhile) // while loop
{ {
...@@ -674,15 +687,15 @@ bool TOutputGLSLBase::visitLoop(Visit visit, TIntermLoop* node) ...@@ -674,15 +687,15 @@ bool TOutputGLSLBase::visitLoop(Visit visit, TIntermLoop* node)
// Loop body. // Loop body.
if (node->getUnrollFlag()) if (node->getUnrollFlag())
{ {
TLoopIndexInfo indexInfo; out << "{\n";
mLoopUnroll.FillLoopIndexInfo(node, indexInfo); mLoopUnrollStack.push(node);
mLoopUnroll.Push(indexInfo); while (mLoopUnrollStack.satisfiesLoopCondition())
while (mLoopUnroll.SatisfiesLoopCondition())
{ {
visitCodeBlock(node->getBody()); visitCodeBlock(node->getBody());
mLoopUnroll.Step(); mLoopUnrollStack.step();
} }
mLoopUnroll.Pop(); mLoopUnrollStack.pop();
out << "}\n";
} }
else else
{ {
......
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
#include <set> #include <set>
#include "compiler/translator/ForLoopUnroll.h"
#include "compiler/translator/intermediate.h" #include "compiler/translator/intermediate.h"
#include "compiler/translator/LoopInfo.h"
#include "compiler/translator/ParseContext.h" #include "compiler/translator/ParseContext.h"
class TOutputGLSLBase : public TIntermTraverser class TOutputGLSLBase : public TIntermTraverser
...@@ -64,7 +64,8 @@ private: ...@@ -64,7 +64,8 @@ private:
typedef std::set<TString> DeclaredStructs; typedef std::set<TString> DeclaredStructs;
DeclaredStructs mDeclaredStructs; DeclaredStructs mDeclaredStructs;
ForLoopUnroll mLoopUnroll; // Stack of loops that need to be unrolled.
TLoopStack mLoopUnrollStack;
ShArrayIndexClampingStrategy mClampingStrategy; ShArrayIndexClampingStrategy mClampingStrategy;
......
...@@ -9,25 +9,8 @@ ...@@ -9,25 +9,8 @@
#include "compiler/translator/InitializeParseContext.h" #include "compiler/translator/InitializeParseContext.h"
#include "compiler/translator/ParseContext.h" #include "compiler/translator/ParseContext.h"
namespace { namespace
bool IsLoopIndex(const TIntermSymbol* symbol, const TLoopStack& stack) { {
for (TLoopStack::const_iterator i = stack.begin(); i != stack.end(); ++i) {
if (i->index.id == symbol->getId())
return true;
}
return false;
}
void MarkLoopForUnroll(const TIntermSymbol* symbol, TLoopStack& stack) {
for (TLoopStack::iterator i = stack.begin(); i != stack.end(); ++i) {
if (i->index.id == symbol->getId()) {
ASSERT(i->loop != NULL);
i->loop->setUnrollFlag(true);
return;
}
}
UNREACHABLE();
}
// Traverses a node to check if it represents a constant index expression. // Traverses a node to check if it represents a constant index expression.
// Definition: // Definition:
...@@ -38,110 +21,60 @@ void MarkLoopForUnroll(const TIntermSymbol* symbol, TLoopStack& stack) { ...@@ -38,110 +21,60 @@ void MarkLoopForUnroll(const TIntermSymbol* symbol, TLoopStack& stack) {
// - Constant expressions // - Constant expressions
// - Loop indices as defined in section 4 // - Loop indices as defined in section 4
// - Expressions composed of both of the above // - Expressions composed of both of the above
class ValidateConstIndexExpr : public TIntermTraverser { class ValidateConstIndexExpr : public TIntermTraverser
public: {
ValidateConstIndexExpr(const TLoopStack& stack) public:
ValidateConstIndexExpr(TLoopStack& stack)
: mValid(true), mLoopStack(stack) {} : mValid(true), mLoopStack(stack) {}
// Returns true if the parsed node represents a constant index expression. // Returns true if the parsed node represents a constant index expression.
bool isValid() const { return mValid; } bool isValid() const { return mValid; }
virtual void visitSymbol(TIntermSymbol* symbol) { virtual void visitSymbol(TIntermSymbol *symbol)
{
// Only constants and loop indices are allowed in a // Only constants and loop indices are allowed in a
// constant index expression. // constant index expression.
if (mValid) { if (mValid)
{
mValid = (symbol->getQualifier() == EvqConst) || mValid = (symbol->getQualifier() == EvqConst) ||
IsLoopIndex(symbol, mLoopStack); (mLoopStack.findLoop(symbol));
} }
} }
private: private:
bool mValid; bool mValid;
const TLoopStack& mLoopStack;
};
// Traverses a node to check if it uses a loop index.
// If an int loop index is used in its body as a sampler array index,
// mark the loop for unroll.
class ValidateLoopIndexExpr : public TIntermTraverser {
public:
ValidateLoopIndexExpr(TLoopStack& stack)
: mUsesFloatLoopIndex(false),
mUsesIntLoopIndex(false),
mLoopStack(stack) {}
bool usesFloatLoopIndex() const { return mUsesFloatLoopIndex; }
bool usesIntLoopIndex() const { return mUsesIntLoopIndex; }
virtual void visitSymbol(TIntermSymbol* symbol) {
if (IsLoopIndex(symbol, mLoopStack)) {
switch (symbol->getBasicType()) {
case EbtFloat:
mUsesFloatLoopIndex = true;
break;
case EbtInt:
mUsesIntLoopIndex = true;
MarkLoopForUnroll(symbol, mLoopStack);
break;
default:
UNREACHABLE();
}
}
}
private:
bool mUsesFloatLoopIndex;
bool mUsesIntLoopIndex;
TLoopStack& mLoopStack; TLoopStack& mLoopStack;
}; };
} // namespace
} // namespace anonymous
ValidateLimitations::ValidateLimitations(ShShaderType shaderType, ValidateLimitations::ValidateLimitations(ShShaderType shaderType,
TInfoSinkBase& sink) TInfoSinkBase &sink)
: mShaderType(shaderType), : mShaderType(shaderType),
mSink(sink), mSink(sink),
mNumErrors(0) mNumErrors(0)
{ {
} }
bool ValidateLimitations::visitBinary(Visit, TIntermBinary* node) bool ValidateLimitations::visitBinary(Visit, TIntermBinary *node)
{ {
// Check if loop index is modified in the loop body. // Check if loop index is modified in the loop body.
validateOperation(node, node->getLeft()); validateOperation(node, node->getLeft());
// Check indexing. // Check indexing.
switch (node->getOp()) { switch (node->getOp())
{
case EOpIndexDirect: case EOpIndexDirect:
validateIndexing(node);
break;
case EOpIndexIndirect: case EOpIndexIndirect:
#if defined(__APPLE__)
// Loop unrolling is a work-around for a Mac Cg compiler bug where it
// crashes when a sampler array's index is also the loop index.
// Once Apple fixes this bug, we should remove the code in this CL.
// See http://codereview.appspot.com/4331048/.
if ((node->getLeft() != NULL) && (node->getRight() != NULL) &&
(node->getLeft()->getAsSymbolNode())) {
TIntermSymbol* symbol = node->getLeft()->getAsSymbolNode();
if (IsSampler(symbol->getBasicType()) && symbol->isArray()) {
ValidateLoopIndexExpr validate(mLoopStack);
node->getRight()->traverse(&validate);
if (validate.usesFloatLoopIndex()) {
error(node->getLine(),
"sampler array index is float loop index",
"for");
}
}
}
#endif
validateIndexing(node); validateIndexing(node);
break; break;
default: break; default:
break;
} }
return true; return true;
} }
bool ValidateLimitations::visitUnary(Visit, TIntermUnary* node) bool ValidateLimitations::visitUnary(Visit, TIntermUnary *node)
{ {
// Check if loop index is modified in the loop body. // Check if loop index is modified in the loop body.
validateOperation(node, node->getOperand()); validateOperation(node, node->getOperand());
...@@ -149,7 +82,7 @@ bool ValidateLimitations::visitUnary(Visit, TIntermUnary* node) ...@@ -149,7 +82,7 @@ bool ValidateLimitations::visitUnary(Visit, TIntermUnary* node)
return true; return true;
} }
bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate* node) bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate *node)
{ {
switch (node->getOp()) { switch (node->getOp()) {
case EOpFunctionCall: case EOpFunctionCall:
...@@ -161,22 +94,20 @@ bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate* node) ...@@ -161,22 +94,20 @@ bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate* node)
return true; return true;
} }
bool ValidateLimitations::visitLoop(Visit, TIntermLoop* node) bool ValidateLimitations::visitLoop(Visit, TIntermLoop *node)
{ {
if (!validateLoopType(node)) if (!validateLoopType(node))
return false; return false;
TLoopInfo info; if (!validateForLoopHeader(node))
memset(&info, 0, sizeof(TLoopInfo));
info.loop = node;
if (!validateForLoopHeader(node, &info))
return false; return false;
TIntermNode* body = node->getBody(); TIntermNode *body = node->getBody();
if (body != NULL) { if (body != NULL)
mLoopStack.push_back(info); {
mLoopStack.push(node);
body->traverse(this); body->traverse(this);
mLoopStack.pop_back(); mLoopStack.pop();
} }
// The loop is fully processed - no need to visit children. // The loop is fully processed - no need to visit children.
...@@ -184,7 +115,7 @@ bool ValidateLimitations::visitLoop(Visit, TIntermLoop* node) ...@@ -184,7 +115,7 @@ bool ValidateLimitations::visitLoop(Visit, TIntermLoop* node)
} }
void ValidateLimitations::error(TSourceLoc loc, void ValidateLimitations::error(TSourceLoc loc,
const char *reason, const char* token) const char *reason, const char *token)
{ {
mSink.prefix(EPrefixError); mSink.prefix(EPrefixError);
mSink.location(loc); mSink.location(loc);
...@@ -197,12 +128,13 @@ bool ValidateLimitations::withinLoopBody() const ...@@ -197,12 +128,13 @@ bool ValidateLimitations::withinLoopBody() const
return !mLoopStack.empty(); return !mLoopStack.empty();
} }
bool ValidateLimitations::isLoopIndex(const TIntermSymbol* symbol) const bool ValidateLimitations::isLoopIndex(TIntermSymbol *symbol)
{ {
return IsLoopIndex(symbol, mLoopStack); return mLoopStack.findLoop(symbol) != NULL;
} }
bool ValidateLimitations::validateLoopType(TIntermLoop* node) { bool ValidateLimitations::validateLoopType(TIntermLoop *node)
{
TLoopType type = node->getType(); TLoopType type = node->getType();
if (type == ELoopFor) if (type == ELoopFor)
return true; return true;
...@@ -214,8 +146,7 @@ bool ValidateLimitations::validateLoopType(TIntermLoop* node) { ...@@ -214,8 +146,7 @@ bool ValidateLimitations::validateLoopType(TIntermLoop* node) {
return false; return false;
} }
bool ValidateLimitations::validateForLoopHeader(TIntermLoop* node, bool ValidateLimitations::validateForLoopHeader(TIntermLoop *node)
TLoopInfo* info)
{ {
ASSERT(node->getType() == ELoopFor); ASSERT(node->getType() == ELoopFor);
...@@ -223,74 +154,81 @@ bool ValidateLimitations::validateForLoopHeader(TIntermLoop* node, ...@@ -223,74 +154,81 @@ bool ValidateLimitations::validateForLoopHeader(TIntermLoop* node,
// The for statement has the form: // The for statement has the form:
// for ( init-declaration ; condition ; expression ) statement // for ( init-declaration ; condition ; expression ) statement
// //
if (!validateForLoopInit(node, info)) int indexSymbolId = validateForLoopInit(node);
if (indexSymbolId < 0)
return false; return false;
if (!validateForLoopCond(node, info)) if (!validateForLoopCond(node, indexSymbolId))
return false; return false;
if (!validateForLoopExpr(node, info)) if (!validateForLoopExpr(node, indexSymbolId))
return false; return false;
return true; return true;
} }
bool ValidateLimitations::validateForLoopInit(TIntermLoop* node, int ValidateLimitations::validateForLoopInit(TIntermLoop *node)
TLoopInfo* info)
{ {
TIntermNode* init = node->getInit(); TIntermNode *init = node->getInit();
if (init == NULL) { if (init == NULL)
{
error(node->getLine(), "Missing init declaration", "for"); error(node->getLine(), "Missing init declaration", "for");
return false; return -1;
} }
// //
// init-declaration has the form: // init-declaration has the form:
// type-specifier identifier = constant-expression // type-specifier identifier = constant-expression
// //
TIntermAggregate* decl = init->getAsAggregate(); TIntermAggregate *decl = init->getAsAggregate();
if ((decl == NULL) || (decl->getOp() != EOpDeclaration)) { if ((decl == NULL) || (decl->getOp() != EOpDeclaration))
{
error(init->getLine(), "Invalid init declaration", "for"); error(init->getLine(), "Invalid init declaration", "for");
return false; return -1;
} }
// To keep things simple do not allow declaration list. // To keep things simple do not allow declaration list.
TIntermSequence& declSeq = decl->getSequence(); TIntermSequence &declSeq = decl->getSequence();
if (declSeq.size() != 1) { if (declSeq.size() != 1)
{
error(decl->getLine(), "Invalid init declaration", "for"); error(decl->getLine(), "Invalid init declaration", "for");
return false; return -1;
} }
TIntermBinary* declInit = declSeq[0]->getAsBinaryNode(); TIntermBinary *declInit = declSeq[0]->getAsBinaryNode();
if ((declInit == NULL) || (declInit->getOp() != EOpInitialize)) { if ((declInit == NULL) || (declInit->getOp() != EOpInitialize))
{
error(decl->getLine(), "Invalid init declaration", "for"); error(decl->getLine(), "Invalid init declaration", "for");
return false; return -1;
} }
TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode(); TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
if (symbol == NULL) { if (symbol == NULL)
{
error(declInit->getLine(), "Invalid init declaration", "for"); error(declInit->getLine(), "Invalid init declaration", "for");
return false; return -1;
} }
// The loop index has type int or float. // The loop index has type int or float.
TBasicType type = symbol->getBasicType(); TBasicType type = symbol->getBasicType();
if ((type != EbtInt) && (type != EbtFloat)) { if ((type != EbtInt) && (type != EbtFloat))
{
error(symbol->getLine(), error(symbol->getLine(),
"Invalid type for loop index", getBasicString(type)); "Invalid type for loop index", getBasicString(type));
return false; return -1;
} }
// The loop index is initialized with constant expression. // The loop index is initialized with constant expression.
if (!isConstExpr(declInit->getRight())) { if (!isConstExpr(declInit->getRight()))
{
error(declInit->getLine(), error(declInit->getLine(),
"Loop index cannot be initialized with non-constant expression", "Loop index cannot be initialized with non-constant expression",
symbol->getSymbol().c_str()); symbol->getSymbol().c_str());
return false; return -1;
} }
info->index.id = symbol->getId(); return symbol->getId();
return true;
} }
bool ValidateLimitations::validateForLoopCond(TIntermLoop* node, bool ValidateLimitations::validateForLoopCond(TIntermLoop *node,
TLoopInfo* info) int indexSymbolId)
{ {
TIntermNode* cond = node->getCondition(); TIntermNode *cond = node->getCondition();
if (cond == NULL) { if (cond == NULL)
{
error(node->getLine(), "Missing condition", "for"); error(node->getLine(), "Missing condition", "for");
return false; return false;
} }
...@@ -298,24 +236,28 @@ bool ValidateLimitations::validateForLoopCond(TIntermLoop* node, ...@@ -298,24 +236,28 @@ bool ValidateLimitations::validateForLoopCond(TIntermLoop* node,
// condition has the form: // condition has the form:
// loop_index relational_operator constant_expression // loop_index relational_operator constant_expression
// //
TIntermBinary* binOp = cond->getAsBinaryNode(); TIntermBinary *binOp = cond->getAsBinaryNode();
if (binOp == NULL) { if (binOp == NULL)
{
error(node->getLine(), "Invalid condition", "for"); error(node->getLine(), "Invalid condition", "for");
return false; return false;
} }
// Loop index should be to the left of relational operator. // Loop index should be to the left of relational operator.
TIntermSymbol* symbol = binOp->getLeft()->getAsSymbolNode(); TIntermSymbol *symbol = binOp->getLeft()->getAsSymbolNode();
if (symbol == NULL) { if (symbol == NULL)
{
error(binOp->getLine(), "Invalid condition", "for"); error(binOp->getLine(), "Invalid condition", "for");
return false; return false;
} }
if (symbol->getId() != info->index.id) { if (symbol->getId() != indexSymbolId)
{
error(symbol->getLine(), error(symbol->getLine(),
"Expected loop index", symbol->getSymbol().c_str()); "Expected loop index", symbol->getSymbol().c_str());
return false; return false;
} }
// Relational operator is one of: > >= < <= == or !=. // Relational operator is one of: > >= < <= == or !=.
switch (binOp->getOp()) { switch (binOp->getOp())
{
case EOpEqual: case EOpEqual:
case EOpNotEqual: case EOpNotEqual:
case EOpLessThan: case EOpLessThan:
...@@ -330,7 +272,8 @@ bool ValidateLimitations::validateForLoopCond(TIntermLoop* node, ...@@ -330,7 +272,8 @@ bool ValidateLimitations::validateForLoopCond(TIntermLoop* node,
break; break;
} }
// Loop index must be compared with a constant. // Loop index must be compared with a constant.
if (!isConstExpr(binOp->getRight())) { if (!isConstExpr(binOp->getRight()))
{
error(binOp->getLine(), error(binOp->getLine(),
"Loop index cannot be compared with non-constant expression", "Loop index cannot be compared with non-constant expression",
symbol->getSymbol().c_str()); symbol->getSymbol().c_str());
...@@ -340,11 +283,12 @@ bool ValidateLimitations::validateForLoopCond(TIntermLoop* node, ...@@ -340,11 +283,12 @@ bool ValidateLimitations::validateForLoopCond(TIntermLoop* node,
return true; return true;
} }
bool ValidateLimitations::validateForLoopExpr(TIntermLoop* node, bool ValidateLimitations::validateForLoopExpr(TIntermLoop *node,
TLoopInfo* info) int indexSymbolId)
{ {
TIntermNode* expr = node->getExpression(); TIntermNode *expr = node->getExpression();
if (expr == NULL) { if (expr == NULL)
{
error(node->getLine(), "Missing expression", "for"); error(node->getLine(), "Missing expression", "for");
return false; return false;
} }
...@@ -358,32 +302,38 @@ bool ValidateLimitations::validateForLoopExpr(TIntermLoop* node, ...@@ -358,32 +302,38 @@ bool ValidateLimitations::validateForLoopExpr(TIntermLoop* node,
// --loop_index // --loop_index
// The last two forms are not specified in the spec, but I am assuming // The last two forms are not specified in the spec, but I am assuming
// its an oversight. // its an oversight.
TIntermUnary* unOp = expr->getAsUnaryNode(); TIntermUnary *unOp = expr->getAsUnaryNode();
TIntermBinary* binOp = unOp ? NULL : expr->getAsBinaryNode(); TIntermBinary *binOp = unOp ? NULL : expr->getAsBinaryNode();
TOperator op = EOpNull; TOperator op = EOpNull;
TIntermSymbol* symbol = NULL; TIntermSymbol *symbol = NULL;
if (unOp != NULL) { if (unOp != NULL)
{
op = unOp->getOp(); op = unOp->getOp();
symbol = unOp->getOperand()->getAsSymbolNode(); symbol = unOp->getOperand()->getAsSymbolNode();
} else if (binOp != NULL) { }
else if (binOp != NULL)
{
op = binOp->getOp(); op = binOp->getOp();
symbol = binOp->getLeft()->getAsSymbolNode(); symbol = binOp->getLeft()->getAsSymbolNode();
} }
// The operand must be loop index. // The operand must be loop index.
if (symbol == NULL) { if (symbol == NULL)
{
error(expr->getLine(), "Invalid expression", "for"); error(expr->getLine(), "Invalid expression", "for");
return false; return false;
} }
if (symbol->getId() != info->index.id) { if (symbol->getId() != indexSymbolId)
{
error(symbol->getLine(), error(symbol->getLine(),
"Expected loop index", symbol->getSymbol().c_str()); "Expected loop index", symbol->getSymbol().c_str());
return false; return false;
} }
// The operator is one of: ++ -- += -=. // The operator is one of: ++ -- += -=.
switch (op) { switch (op)
{
case EOpPostIncrement: case EOpPostIncrement:
case EOpPostDecrement: case EOpPostDecrement:
case EOpPreIncrement: case EOpPreIncrement:
...@@ -400,8 +350,10 @@ bool ValidateLimitations::validateForLoopExpr(TIntermLoop* node, ...@@ -400,8 +350,10 @@ bool ValidateLimitations::validateForLoopExpr(TIntermLoop* node,
} }
// Loop index must be incremented/decremented with a constant. // Loop index must be incremented/decremented with a constant.
if (binOp != NULL) { if (binOp != NULL)
if (!isConstExpr(binOp->getRight())) { {
if (!isConstExpr(binOp->getRight()))
{
error(binOp->getLine(), error(binOp->getLine(),
"Loop index cannot be modified by non-constant expression", "Loop index cannot be modified by non-constant expression",
symbol->getSymbol().c_str()); symbol->getSymbol().c_str());
...@@ -412,7 +364,7 @@ bool ValidateLimitations::validateForLoopExpr(TIntermLoop* node, ...@@ -412,7 +364,7 @@ bool ValidateLimitations::validateForLoopExpr(TIntermLoop* node,
return true; return true;
} }
bool ValidateLimitations::validateFunctionCall(TIntermAggregate* node) bool ValidateLimitations::validateFunctionCall(TIntermAggregate *node)
{ {
ASSERT(node->getOp() == EOpFunctionCall); ASSERT(node->getOp() == EOpFunctionCall);
...@@ -424,8 +376,9 @@ bool ValidateLimitations::validateFunctionCall(TIntermAggregate* node) ...@@ -424,8 +376,9 @@ bool ValidateLimitations::validateFunctionCall(TIntermAggregate* node)
typedef std::vector<size_t> ParamIndex; typedef std::vector<size_t> ParamIndex;
ParamIndex pIndex; ParamIndex pIndex;
TIntermSequence& params = node->getSequence(); TIntermSequence& params = node->getSequence();
for (TIntermSequence::size_type i = 0; i < params.size(); ++i) { for (TIntermSequence::size_type i = 0; i < params.size(); ++i)
TIntermSymbol* symbol = params[i]->getAsSymbolNode(); {
TIntermSymbol *symbol = params[i]->getAsSymbolNode();
if (symbol && isLoopIndex(symbol)) if (symbol && isLoopIndex(symbol))
pIndex.push_back(i); pIndex.push_back(i);
} }
...@@ -436,14 +389,16 @@ bool ValidateLimitations::validateFunctionCall(TIntermAggregate* node) ...@@ -436,14 +389,16 @@ bool ValidateLimitations::validateFunctionCall(TIntermAggregate* node)
bool valid = true; bool valid = true;
TSymbolTable& symbolTable = GetGlobalParseContext()->symbolTable; TSymbolTable& symbolTable = GetGlobalParseContext()->symbolTable;
TSymbol* symbol = symbolTable.find(node->getName()); TSymbol *symbol = symbolTable.find(node->getName());
ASSERT(symbol && symbol->isFunction()); ASSERT(symbol && symbol->isFunction());
TFunction* function = static_cast<TFunction*>(symbol); TFunction *function = static_cast<TFunction *>(symbol);
for (ParamIndex::const_iterator i = pIndex.begin(); for (ParamIndex::const_iterator i = pIndex.begin();
i != pIndex.end(); ++i) { i != pIndex.end(); ++i)
const TParameter& param = function->getParam(*i); {
const TParameter &param = function->getParam(*i);
TQualifier qual = param.type->getQualifier(); TQualifier qual = param.type->getQualifier();
if ((qual == EvqOut) || (qual == EvqInOut)) { if ((qual == EvqOut) || (qual == EvqInOut))
{
error(params[*i]->getLine(), error(params[*i]->getLine(),
"Loop index cannot be used as argument to a function out or inout parameter", "Loop index cannot be used as argument to a function out or inout parameter",
params[*i]->getAsSymbolNode()->getSymbol().c_str()); params[*i]->getAsSymbolNode()->getSymbol().c_str());
...@@ -454,14 +409,16 @@ bool ValidateLimitations::validateFunctionCall(TIntermAggregate* node) ...@@ -454,14 +409,16 @@ bool ValidateLimitations::validateFunctionCall(TIntermAggregate* node)
return valid; return valid;
} }
bool ValidateLimitations::validateOperation(TIntermOperator* node, bool ValidateLimitations::validateOperation(TIntermOperator *node,
TIntermNode* operand) { TIntermNode* operand)
{
// Check if loop index is modified in the loop body. // Check if loop index is modified in the loop body.
if (!withinLoopBody() || !node->isAssignment()) if (!withinLoopBody() || !node->isAssignment())
return true; return true;
const TIntermSymbol* symbol = operand->getAsSymbolNode(); TIntermSymbol *symbol = operand->getAsSymbolNode();
if (symbol && isLoopIndex(symbol)) { if (symbol && isLoopIndex(symbol))
{
error(node->getLine(), error(node->getLine(),
"Loop index cannot be statically assigned to within the body of the loop", "Loop index cannot be statically assigned to within the body of the loop",
symbol->getSymbol().c_str()); symbol->getSymbol().c_str());
...@@ -469,13 +426,13 @@ bool ValidateLimitations::validateOperation(TIntermOperator* node, ...@@ -469,13 +426,13 @@ bool ValidateLimitations::validateOperation(TIntermOperator* node,
return true; return true;
} }
bool ValidateLimitations::isConstExpr(TIntermNode* node) bool ValidateLimitations::isConstExpr(TIntermNode *node)
{ {
ASSERT(node != NULL); ASSERT(node != NULL);
return node->getAsConstantUnion() != NULL; return node->getAsConstantUnion() != NULL;
} }
bool ValidateLimitations::isConstIndexExpr(TIntermNode* node) bool ValidateLimitations::isConstIndexExpr(TIntermNode *node)
{ {
ASSERT(node != NULL); ASSERT(node != NULL);
...@@ -484,15 +441,16 @@ bool ValidateLimitations::isConstIndexExpr(TIntermNode* node) ...@@ -484,15 +441,16 @@ bool ValidateLimitations::isConstIndexExpr(TIntermNode* node)
return validate.isValid(); return validate.isValid();
} }
bool ValidateLimitations::validateIndexing(TIntermBinary* node) bool ValidateLimitations::validateIndexing(TIntermBinary *node)
{ {
ASSERT((node->getOp() == EOpIndexDirect) || ASSERT((node->getOp() == EOpIndexDirect) ||
(node->getOp() == EOpIndexIndirect)); (node->getOp() == EOpIndexIndirect));
bool valid = true; bool valid = true;
TIntermTyped* index = node->getRight(); TIntermTyped *index = node->getRight();
// The index expression must have integral type. // The index expression must have integral type.
if (!index->isScalar() || (index->getBasicType() != EbtInt)) { if (!index->isScalar() || (index->getBasicType() != EbtInt))
{
error(index->getLine(), error(index->getLine(),
"Index expression must have integral type", "Index expression must have integral type",
index->getCompleteString().c_str()); index->getCompleteString().c_str());
...@@ -500,10 +458,11 @@ bool ValidateLimitations::validateIndexing(TIntermBinary* node) ...@@ -500,10 +458,11 @@ bool ValidateLimitations::validateIndexing(TIntermBinary* node)
} }
// The index expession must be a constant-index-expression unless // The index expession must be a constant-index-expression unless
// the operand is a uniform in a vertex shader. // the operand is a uniform in a vertex shader.
TIntermTyped* operand = node->getLeft(); TIntermTyped *operand = node->getLeft();
bool skip = (mShaderType == SH_VERTEX_SHADER) && bool skip = (mShaderType == SH_VERTEX_SHADER) &&
(operand->getQualifier() == EvqUniform); (operand->getQualifier() == EvqUniform);
if (!skip && !isConstIndexExpr(index)) { if (!skip && !isConstIndexExpr(index))
{
error(index->getLine(), "Index expression must be constant", "[]"); error(index->getLine(), "Index expression must be constant", "[]");
valid = false; valid = false;
} }
......
...@@ -6,53 +6,50 @@ ...@@ -6,53 +6,50 @@
#include "GLSLANG/ShaderLang.h" #include "GLSLANG/ShaderLang.h"
#include "compiler/translator/intermediate.h" #include "compiler/translator/intermediate.h"
#include "compiler/translator/LoopInfo.h"
class TInfoSinkBase; class TInfoSinkBase;
struct TLoopInfo {
struct TIndex {
int id; // symbol id.
} index;
TIntermLoop* loop;
};
typedef TVector<TLoopInfo> TLoopStack;
// Traverses intermediate tree to ensure that the shader does not exceed the // Traverses intermediate tree to ensure that the shader does not exceed the
// minimum functionality mandated in GLSL 1.0 spec, Appendix A. // minimum functionality mandated in GLSL 1.0 spec, Appendix A.
class ValidateLimitations : public TIntermTraverser { class ValidateLimitations : public TIntermTraverser
public: {
ValidateLimitations(ShShaderType shaderType, TInfoSinkBase& sink); public:
ValidateLimitations(ShShaderType shaderType, TInfoSinkBase &sink);
int numErrors() const { return mNumErrors; } int numErrors() const { return mNumErrors; }
virtual bool visitBinary(Visit, TIntermBinary*); virtual bool visitBinary(Visit, TIntermBinary *);
virtual bool visitUnary(Visit, TIntermUnary*); virtual bool visitUnary(Visit, TIntermUnary *);
virtual bool visitAggregate(Visit, TIntermAggregate*); virtual bool visitAggregate(Visit, TIntermAggregate *);
virtual bool visitLoop(Visit, TIntermLoop*); virtual bool visitLoop(Visit, TIntermLoop *);
private: private:
void error(TSourceLoc loc, const char *reason, const char* token); void error(TSourceLoc loc, const char *reason, const char *token);
bool withinLoopBody() const; bool withinLoopBody() const;
bool isLoopIndex(const TIntermSymbol* symbol) const; bool isLoopIndex(TIntermSymbol *symbol);
bool validateLoopType(TIntermLoop* node); bool validateLoopType(TIntermLoop *node);
bool validateForLoopHeader(TIntermLoop* node, TLoopInfo* info);
bool validateForLoopInit(TIntermLoop* node, TLoopInfo* info); bool validateForLoopHeader(TIntermLoop *node);
bool validateForLoopCond(TIntermLoop* node, TLoopInfo* info); // If valid, return the index symbol id; Otherwise, return -1.
bool validateForLoopExpr(TIntermLoop* node, TLoopInfo* info); int validateForLoopInit(TIntermLoop *node);
bool validateForLoopCond(TIntermLoop *node, int indexSymbolId);
bool validateForLoopExpr(TIntermLoop *node, int indexSymbolId);
// Returns true if none of the loop indices is used as the argument to // Returns true if none of the loop indices is used as the argument to
// the given function out or inout parameter. // the given function out or inout parameter.
bool validateFunctionCall(TIntermAggregate* node); bool validateFunctionCall(TIntermAggregate *node);
bool validateOperation(TIntermOperator* node, TIntermNode* operand); bool validateOperation(TIntermOperator *node, TIntermNode *operand);
// Returns true if indexing does not exceed the minimum functionality // Returns true if indexing does not exceed the minimum functionality
// mandated in GLSL 1.0 spec, Appendix A, Section 5. // mandated in GLSL 1.0 spec, Appendix A, Section 5.
bool isConstExpr(TIntermNode* node); bool isConstExpr(TIntermNode *node);
bool isConstIndexExpr(TIntermNode* node); bool isConstIndexExpr(TIntermNode *node);
bool validateIndexing(TIntermBinary* node); bool validateIndexing(TIntermBinary *node);
ShShaderType mShaderType; ShShaderType mShaderType;
TInfoSinkBase& mSink; TInfoSinkBase &mSink;
int mNumErrors; int mNumErrors;
TLoopStack mLoopStack; TLoopStack mLoopStack;
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment