Commit 0b8d4eb2 by zmo@google.com

Unroll for-loop if sampler array uses loop index as its index.

If inside a for-loop, sampler array index is the loop index, Mac cg compiler will crash. This CL unroll the loop in such situation. The behavior is: 1) If the for-loop index is a float, we reject the shader. 2) If it is an integer, we unroll the for-loop. Things that should be done in the future are: 1) Add line number macros. 2) Add a limit to unroll iteration count. anglebug=94 Review URL: http://codereview.appspot.com/4331048 git-svn-id: https://angleproject.googlecode.com/svn/trunk@606 736b8ea6-26fd-11df-bfd4-992fa37f6226
parent f02c9e62
...@@ -100,6 +100,8 @@ ...@@ -100,6 +100,8 @@
], ],
'sources': [ 'sources': [
'compiler/CodeGenGLSL.cpp', 'compiler/CodeGenGLSL.cpp',
'compiler/ForLoopUnroll.cpp',
'compiler/ForLoopUnroll.h',
'compiler/OutputGLSL.cpp', 'compiler/OutputGLSL.cpp',
'compiler/OutputGLSL.h', 'compiler/OutputGLSL.h',
'compiler/TranslatorGLSL.cpp', 'compiler/TranslatorGLSL.cpp',
......
#define MAJOR_VERSION 0 #define MAJOR_VERSION 0
#define MINOR_VERSION 0 #define MINOR_VERSION 0
#define BUILD_VERSION 0 #define BUILD_VERSION 0
#define BUILD_REVISION 605 #define BUILD_REVISION 606
#define STRINGIFY(x) #x #define STRINGIFY(x) #x
#define MACRO_STRINGIFY(x) STRINGIFY(x) #define MACRO_STRINGIFY(x) STRINGIFY(x)
......
//
// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
#include "compiler/ForLoopUnroll.h"
void ForLoopUnroll::FillLoopIndexInfo(TIntermLoop* node, TLoopIndexInfo& info)
{
ASSERT(node->getType() == ELoopFor);
ASSERT(node->getUnrollFlag());
TIntermNode* init = node->getInit();
ASSERT(init != NULL);
TIntermAggregate* decl = init->getAsAggregate();
ASSERT((decl != NULL) && (decl->getOp() == EOpDeclaration));
TIntermSequence& declSeq = decl->getSequence();
ASSERT(declSeq.size() == 1);
TIntermBinary* declInit = declSeq[0]->getAsBinaryNode();
ASSERT((declInit != NULL) && (declInit->getOp() == EOpInitialize));
TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode();
ASSERT(symbol != NULL);
ASSERT(symbol->getBasicType() == EbtInt);
info.id = symbol->getId();
ASSERT(declInit->getRight() != NULL);
TIntermConstantUnion* initNode = declInit->getRight()->getAsConstantUnion();
ASSERT(initNode != NULL);
info.initValue = evaluateIntConstant(initNode);
info.currentValue = info.initValue;
TIntermNode* cond = node->getCondition();
ASSERT(cond != NULL);
TIntermBinary* binOp = cond->getAsBinaryNode();
ASSERT(binOp != NULL);
ASSERT(binOp->getRight() != NULL);
ASSERT(binOp->getRight()->getAsConstantUnion() != NULL);
info.incrementValue = getLoopIncrement(node);
info.stopValue = evaluateIntConstant(
binOp->getRight()->getAsConstantUnion());
info.op = binOp->getOp();
}
void ForLoopUnroll::Step()
{
ASSERT(mLoopIndexStack.size() > 0);
TLoopIndexInfo& info = mLoopIndexStack[mLoopIndexStack.size() - 1];
info.currentValue += info.incrementValue;
}
bool ForLoopUnroll::SatisfiesLoopCondition()
{
ASSERT(mLoopIndexStack.size() > 0);
TLoopIndexInfo& info = mLoopIndexStack[mLoopIndexStack.size() - 1];
// Relational operator is one of: > >= < <= == or !=.
switch (info.op) {
case EOpEqual:
return (info.currentValue == info.stopValue);
case EOpNotEqual:
return (info.currentValue != info.stopValue);
case EOpLessThan:
return (info.currentValue < info.stopValue);
case EOpGreaterThan:
return (info.currentValue > info.stopValue);
case EOpLessThanEqual:
return (info.currentValue <= info.stopValue);
case EOpGreaterThanEqual:
return (info.currentValue >= info.stopValue);
default:
UNREACHABLE();
}
return false;
}
bool ForLoopUnroll::NeedsToReplaceSymbolWithValue(TIntermSymbol* symbol)
{
for (TVector<TLoopIndexInfo>::iterator i = mLoopIndexStack.begin();
i != mLoopIndexStack.end();
++i) {
if (i->id == symbol->getId())
return true;
}
return false;
}
int ForLoopUnroll::GetLoopIndexValue(TIntermSymbol* symbol)
{
for (TVector<TLoopIndexInfo>::iterator i = mLoopIndexStack.begin();
i != mLoopIndexStack.end();
++i) {
if (i->id == symbol->getId())
return i->currentValue;
}
UNREACHABLE();
return false;
}
void ForLoopUnroll::Push(TLoopIndexInfo& info)
{
mLoopIndexStack.push_back(info);
}
void ForLoopUnroll::Pop()
{
mLoopIndexStack.pop_back();
}
int ForLoopUnroll::getLoopIncrement(TIntermLoop* node)
{
TIntermNode* expr = node->getExpression();
ASSERT(expr != NULL);
// for expression has one of the following forms:
// loop_index++
// loop_index--
// loop_index += constant_expression
// loop_index -= constant_expression
// ++loop_index
// --loop_index
// The last two forms are not specified in the spec, but I am assuming
// its an oversight.
TIntermUnary* unOp = expr->getAsUnaryNode();
TIntermBinary* binOp = unOp ? NULL : expr->getAsBinaryNode();
TOperator op = EOpNull;
TIntermConstantUnion* incrementNode = NULL;
if (unOp != NULL) {
op = unOp->getOp();
} else if (binOp != NULL) {
op = binOp->getOp();
ASSERT(binOp->getRight() != NULL);
incrementNode = binOp->getRight()->getAsConstantUnion();
ASSERT(incrementNode != NULL);
}
int increment = 0;
// The operator is one of: ++ -- += -=.
switch (op) {
case EOpPostIncrement:
case EOpPreIncrement:
ASSERT((unOp != NULL) && (binOp == NULL));
increment = 1;
break;
case EOpPostDecrement:
case EOpPreDecrement:
ASSERT((unOp != NULL) && (binOp == NULL));
increment = -1;
break;
case EOpAddAssign:
ASSERT((unOp == NULL) && (binOp != NULL));
increment = evaluateIntConstant(incrementNode);
break;
case EOpSubAssign:
ASSERT((unOp == NULL) && (binOp != NULL));
increment = - evaluateIntConstant(incrementNode);
break;
default:
ASSERT(false);
}
return increment;
}
int ForLoopUnroll::evaluateIntConstant(TIntermConstantUnion* node)
{
ASSERT((node != NULL) && (node->getUnionArrayPointer() != NULL));
return node->getUnionArrayPointer()->getIConst();
}
//
// Copyright (c) 2011 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
#include "compiler/intermediate.h"
struct TLoopIndexInfo {
int id;
int initValue;
int stopValue;
int incrementValue;
TOperator op;
int currentValue;
};
class ForLoopUnroll {
public:
ForLoopUnroll() { }
void FillLoopIndexInfo(TIntermLoop* node, TLoopIndexInfo& info);
// Update the info.currentValue for the next loop iteration.
void Step();
// Return false if loop condition is no longer satisfied.
bool SatisfiesLoopCondition();
// Check if the symbol is the index of a loop that's unrolled.
bool NeedsToReplaceSymbolWithValue(TIntermSymbol* symbol);
// Return the current value of a given loop index symbol.
int GetLoopIndexValue(TIntermSymbol* symbol);
void Push(TLoopIndexInfo& info);
void Pop();
private:
int getLoopIncrement(TIntermLoop* node);
int evaluateIntConstant(TIntermConstantUnion* node);
TVector<TLoopIndexInfo> mLoopIndexStack;
};
// //
// Copyright (c) 2002-2010 The ANGLE Project Authors. All rights reserved. // Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be // Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. // found in the LICENSE file.
// //
...@@ -195,7 +195,10 @@ const ConstantUnion* TOutputGLSL::writeConstantUnion(const TType& type, ...@@ -195,7 +195,10 @@ const ConstantUnion* TOutputGLSL::writeConstantUnion(const TType& type,
void TOutputGLSL::visitSymbol(TIntermSymbol* node) void TOutputGLSL::visitSymbol(TIntermSymbol* node)
{ {
TInfoSinkBase& out = objSink(); TInfoSinkBase& out = objSink();
out << node->getSymbol(); if (mLoopUnroll.NeedsToReplaceSymbolWithValue(node))
out << mLoopUnroll.GetLoopIndexValue(node);
else
out << node->getSymbol();
if (mDeclaringVariables && node->getType().isArray()) if (mDeclaringVariables && node->getType().isArray())
out << arrayBrackets(node->getType()); out << arrayBrackets(node->getType());
...@@ -615,18 +618,20 @@ bool TOutputGLSL::visitLoop(Visit visit, TIntermLoop* node) ...@@ -615,18 +618,20 @@ bool TOutputGLSL::visitLoop(Visit visit, TIntermLoop* node)
TLoopType loopType = node->getType(); TLoopType loopType = node->getType();
if (loopType == ELoopFor) // for loop if (loopType == ELoopFor) // for loop
{ {
out << "for ("; if (!node->getUnrollFlag()) {
if (node->getInit()) out << "for (";
node->getInit()->traverse(this); if (node->getInit())
out << "; "; node->getInit()->traverse(this);
out << "; ";
if (node->getCondition())
node->getCondition()->traverse(this); if (node->getCondition())
out << "; "; node->getCondition()->traverse(this);
out << "; ";
if (node->getExpression())
node->getExpression()->traverse(this); if (node->getExpression())
out << ")\n"; node->getExpression()->traverse(this);
out << ")\n";
}
} }
else if (loopType == ELoopWhile) // while loop else if (loopType == ELoopWhile) // while loop
{ {
...@@ -642,7 +647,22 @@ bool TOutputGLSL::visitLoop(Visit visit, TIntermLoop* node) ...@@ -642,7 +647,22 @@ bool TOutputGLSL::visitLoop(Visit visit, TIntermLoop* node)
} }
// Loop body. // Loop body.
visitCodeBlock(node->getBody()); if (node->getUnrollFlag())
{
TLoopIndexInfo indexInfo;
mLoopUnroll.FillLoopIndexInfo(node, indexInfo);
mLoopUnroll.Push(indexInfo);
while (mLoopUnroll.SatisfiesLoopCondition())
{
visitCodeBlock(node->getBody());
mLoopUnroll.Step();
}
mLoopUnroll.Pop();
}
else
{
visitCodeBlock(node->getBody());
}
// Loop footer. // Loop footer.
if (loopType == ELoopDoWhile) // do-while loop if (loopType == ELoopDoWhile) // do-while loop
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <set> #include <set>
#include "compiler/ForLoopUnroll.h"
#include "compiler/intermediate.h" #include "compiler/intermediate.h"
#include "compiler/ParseHelper.h" #include "compiler/ParseHelper.h"
...@@ -44,6 +45,8 @@ private: ...@@ -44,6 +45,8 @@ private:
// declared only once. // declared only once.
typedef std::set<TString> DeclaredStructs; typedef std::set<TString> DeclaredStructs;
DeclaredStructs mDeclaredStructs; DeclaredStructs mDeclaredStructs;
ForLoopUnroll mLoopUnroll;
}; };
#endif // CROSSCOMPILERGLSL_OUTPUTGLSL_H_ #endif // CROSSCOMPILERGLSL_OUTPUTGLSL_H_
...@@ -17,6 +17,17 @@ bool IsLoopIndex(const TIntermSymbol* symbol, const TLoopStack& stack) { ...@@ -17,6 +17,17 @@ bool IsLoopIndex(const TIntermSymbol* symbol, const TLoopStack& stack) {
return false; return false;
} }
void MarkLoopForUnroll(const TIntermSymbol* symbol, TLoopStack& stack) {
for (TLoopStack::iterator i = stack.begin(); i != stack.end(); ++i) {
if (i->index.id == symbol->getId()) {
ASSERT(i->loop != NULL);
i->loop->setUnrollFlag(true);
return;
}
}
UNREACHABLE();
}
// Traverses a node to check if it represents a constant index expression. // Traverses a node to check if it represents a constant index expression.
// Definition: // Definition:
// constant-index-expressions are a superset of constant-expressions. // constant-index-expressions are a superset of constant-expressions.
...@@ -54,6 +65,48 @@ private: ...@@ -54,6 +65,48 @@ private:
bool mValid; bool mValid;
const TLoopStack& mLoopStack; const TLoopStack& mLoopStack;
}; };
// Traverses a node to check if it uses a loop index.
// If an int loop index is used in its body as a sampler array index,
// mark the loop for unroll.
class ValidateLoopIndexExpr : public TIntermTraverser {
public:
ValidateLoopIndexExpr(TLoopStack& stack)
: mUsesFloatLoopIndex(false),
mUsesIntLoopIndex(false),
mLoopStack(stack) {}
bool usesFloatLoopIndex() const { return mUsesFloatLoopIndex; }
bool usesIntLoopIndex() const { return mUsesIntLoopIndex; }
virtual void visitSymbol(TIntermSymbol* symbol) {
if (IsLoopIndex(symbol, mLoopStack)) {
switch (symbol->getBasicType()) {
case EbtFloat:
mUsesFloatLoopIndex = true;
break;
case EbtInt:
mUsesIntLoopIndex = true;
MarkLoopForUnroll(symbol, mLoopStack);
break;
default:
UNREACHABLE();
}
}
}
virtual void visitConstantUnion(TIntermConstantUnion*) {}
virtual bool visitBinary(Visit, TIntermBinary*) { return true; }
virtual bool visitUnary(Visit, TIntermUnary*) { return true; }
virtual bool visitSelection(Visit, TIntermSelection*) { return true; }
virtual bool visitAggregate(Visit, TIntermAggregate*) { return true; }
virtual bool visitLoop(Visit, TIntermLoop*) { return true; }
virtual bool visitBranch(Visit, TIntermBranch*) { return true; }
private:
bool mUsesFloatLoopIndex;
bool mUsesIntLoopIndex;
TLoopStack& mLoopStack;
};
} // namespace } // namespace
ValidateLimitations::ValidateLimitations(ShShaderType shaderType, ValidateLimitations::ValidateLimitations(ShShaderType shaderType,
...@@ -80,7 +133,28 @@ bool ValidateLimitations::visitBinary(Visit, TIntermBinary* node) ...@@ -80,7 +133,28 @@ bool ValidateLimitations::visitBinary(Visit, TIntermBinary* node)
// Check indexing. // Check indexing.
switch (node->getOp()) { switch (node->getOp()) {
case EOpIndexDirect: case EOpIndexDirect:
validateIndexing(node);
break;
case EOpIndexIndirect: case EOpIndexIndirect:
#if defined(__APPLE__)
// Loop unrolling is a work-around for a Mac Cg compiler bug where it
// crashes when a sampler array's index is also the loop index.
// Once Apple fixes this bug, we should remove the code in this CL.
// See http://codereview.appspot.com/4331048/.
if ((node->getLeft() != NULL) && (node->getRight() != NULL) &&
(node->getLeft()->getAsSymbolNode())) {
TIntermSymbol* symbol = node->getLeft()->getAsSymbolNode();
if (IsSampler(symbol->getBasicType()) && symbol->isArray()) {
ValidateLoopIndexExpr validate(mLoopStack);
node->getRight()->traverse(&validate);
if (validate.usesFloatLoopIndex()) {
error(node->getLine(),
"sampler array index is float loop index",
"for");
}
}
}
#endif
validateIndexing(node); validateIndexing(node);
break; break;
default: break; default: break;
...@@ -120,6 +194,7 @@ bool ValidateLimitations::visitLoop(Visit, TIntermLoop* node) ...@@ -120,6 +194,7 @@ bool ValidateLimitations::visitLoop(Visit, TIntermLoop* node)
TLoopInfo info; TLoopInfo info;
memset(&info, 0, sizeof(TLoopInfo)); memset(&info, 0, sizeof(TLoopInfo));
info.loop = node;
if (!validateForLoopHeader(node, &info)) if (!validateForLoopHeader(node, &info))
return false; return false;
......
...@@ -13,6 +13,7 @@ struct TLoopInfo { ...@@ -13,6 +13,7 @@ struct TLoopInfo {
struct TIndex { struct TIndex {
int id; // symbol id. int id; // symbol id.
} index; } index;
TIntermLoop* loop;
}; };
typedef TVector<TLoopInfo> TLoopStack; typedef TVector<TLoopInfo> TLoopStack;
......
...@@ -279,7 +279,8 @@ public: ...@@ -279,7 +279,8 @@ public:
init(aInit), init(aInit),
cond(aCond), cond(aCond),
expr(aExpr), expr(aExpr),
body(aBody) { } body(aBody),
unrollFlag(false) { }
virtual TIntermLoop* getAsLoopNode() { return this; } virtual TIntermLoop* getAsLoopNode() { return this; }
virtual void traverse(TIntermTraverser*); virtual void traverse(TIntermTraverser*);
...@@ -290,12 +291,17 @@ public: ...@@ -290,12 +291,17 @@ public:
TIntermTyped* getExpression() { return expr; } TIntermTyped* getExpression() { return expr; }
TIntermNode* getBody() { return body; } TIntermNode* getBody() { return body; }
void setUnrollFlag(bool flag) { unrollFlag = flag; }
bool getUnrollFlag() { return unrollFlag; }
protected: protected:
TLoopType type; TLoopType type;
TIntermNode* init; // for-loop initialization TIntermNode* init; // for-loop initialization
TIntermTyped* cond; // loop exit condition TIntermTyped* cond; // loop exit condition
TIntermTyped* expr; // for-loop expression TIntermTyped* expr; // for-loop expression
TIntermNode* body; // loop body TIntermNode* body; // loop body
bool unrollFlag; // Whether the loop should be unrolled or not.
}; };
// //
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment