Commit 5e70cf9d by Zhenyao Mo

Add an option to unfold short circuiting in AST.

We replace "a || b" with "a ? true : b", "a && b" with "a ? b : false". This is to work around short circuiting bug in Mac drivers. ANGLEBUG=482 TEST=webgl conformance tests R=alokp@chromium.org, kbr@chromium.org Review URL: https://codereview.appspot.com/14529048
parent 59b77858
...@@ -189,6 +189,13 @@ typedef enum { ...@@ -189,6 +189,13 @@ typedef enum {
// fragment shader. It is intended as a workaround for drivers which // fragment shader. It is intended as a workaround for drivers which
// incorrectly fail to link programs if gl_Position is not written. // incorrectly fail to link programs if gl_Position is not written.
SH_INIT_GL_POSITION = 0x8000, SH_INIT_GL_POSITION = 0x8000,
// This flag replaces
// "a && b" with "a ? b : false",
// "a || b" with "a ? true : b".
// This is to work around a MacOSX driver bug that |b| is executed
// independent of |a|'s value.
SH_UNFOLD_SHORT_CIRCUIT = 0x10000,
} ShCompileOptions; } ShCompileOptions;
// Defines alternate strategies for implementing array index clamping. // Defines alternate strategies for implementing array index clamping.
......
...@@ -82,6 +82,8 @@ ...@@ -82,6 +82,8 @@
'compiler/TranslatorHLSL.cpp', 'compiler/TranslatorHLSL.cpp',
'compiler/TranslatorHLSL.h', 'compiler/TranslatorHLSL.h',
'compiler/Types.h', 'compiler/Types.h',
'compiler/UnfoldShortCircuitAST.cpp',
'compiler/UnfoldShortCircuitAST.h',
'compiler/UnfoldShortCircuit.cpp', 'compiler/UnfoldShortCircuit.cpp',
'compiler/UnfoldShortCircuit.h', 'compiler/UnfoldShortCircuit.h',
'compiler/Uniform.cpp', 'compiler/Uniform.cpp',
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "compiler/ParseHelper.h" #include "compiler/ParseHelper.h"
#include "compiler/RenameFunction.h" #include "compiler/RenameFunction.h"
#include "compiler/ShHandle.h" #include "compiler/ShHandle.h"
#include "compiler/UnfoldShortCircuitAST.h"
#include "compiler/ValidateLimitations.h" #include "compiler/ValidateLimitations.h"
#include "compiler/VariablePacker.h" #include "compiler/VariablePacker.h"
#include "compiler/depgraph/DependencyGraph.h" #include "compiler/depgraph/DependencyGraph.h"
...@@ -194,6 +195,12 @@ bool TCompiler::compile(const char* const shaderStrings[], ...@@ -194,6 +195,12 @@ bool TCompiler::compile(const char* const shaderStrings[],
root->traverse(&initGLPosition); root->traverse(&initGLPosition);
} }
if (success && (compileOptions & SH_UNFOLD_SHORT_CIRCUIT)) {
UnfoldShortCircuitAST unfoldShortCircuit;
root->traverse(&unfoldShortCircuit);
unfoldShortCircuit.updateTree();
}
if (success && (compileOptions & SH_VARIABLES)) { if (success && (compileOptions & SH_VARIABLES)) {
collectVariables(root); collectVariables(root);
if (compileOptions & SH_ENFORCE_PACKING_RESTRICTIONS) { if (compileOptions & SH_ENFORCE_PACKING_RESTRICTIONS) {
......
...@@ -51,7 +51,7 @@ void TIntermBinary::traverse(TIntermTraverser *it) ...@@ -51,7 +51,7 @@ void TIntermBinary::traverse(TIntermTraverser *it)
// //
if (visit) if (visit)
{ {
it->incrementDepth(); it->incrementDepth(this);
if (it->rightToLeft) if (it->rightToLeft)
{ {
...@@ -98,7 +98,7 @@ void TIntermUnary::traverse(TIntermTraverser *it) ...@@ -98,7 +98,7 @@ void TIntermUnary::traverse(TIntermTraverser *it)
visit = it->visitUnary(PreVisit, this); visit = it->visitUnary(PreVisit, this);
if (visit) { if (visit) {
it->incrementDepth(); it->incrementDepth(this);
operand->traverse(it); operand->traverse(it);
it->decrementDepth(); it->decrementDepth();
} }
...@@ -119,7 +119,7 @@ void TIntermAggregate::traverse(TIntermTraverser *it) ...@@ -119,7 +119,7 @@ void TIntermAggregate::traverse(TIntermTraverser *it)
if (visit) if (visit)
{ {
it->incrementDepth(); it->incrementDepth(this);
if (it->rightToLeft) if (it->rightToLeft)
{ {
...@@ -166,7 +166,7 @@ void TIntermSelection::traverse(TIntermTraverser *it) ...@@ -166,7 +166,7 @@ void TIntermSelection::traverse(TIntermTraverser *it)
visit = it->visitSelection(PreVisit, this); visit = it->visitSelection(PreVisit, this);
if (visit) { if (visit) {
it->incrementDepth(); it->incrementDepth(this);
if (it->rightToLeft) { if (it->rightToLeft) {
if (falseBlock) if (falseBlock)
falseBlock->traverse(it); falseBlock->traverse(it);
...@@ -199,7 +199,7 @@ void TIntermLoop::traverse(TIntermTraverser *it) ...@@ -199,7 +199,7 @@ void TIntermLoop::traverse(TIntermTraverser *it)
if (visit) if (visit)
{ {
it->incrementDepth(); it->incrementDepth(this);
if (it->rightToLeft) if (it->rightToLeft)
{ {
...@@ -248,7 +248,7 @@ void TIntermBranch::traverse(TIntermTraverser *it) ...@@ -248,7 +248,7 @@ void TIntermBranch::traverse(TIntermTraverser *it)
visit = it->visitBranch(PreVisit, this); visit = it->visitBranch(PreVisit, this);
if (visit && expression) { if (visit && expression) {
it->incrementDepth(); it->incrementDepth(this);
expression->traverse(it); expression->traverse(it);
it->decrementDepth(); it->decrementDepth();
} }
......
...@@ -19,11 +19,13 @@ ...@@ -19,11 +19,13 @@
bool CompareStructure(const TType& leftNodeType, ConstantUnion* rightUnionArray, ConstantUnion* leftUnionArray); bool CompareStructure(const TType& leftNodeType, ConstantUnion* rightUnionArray, ConstantUnion* leftUnionArray);
static TPrecision GetHigherPrecision( TPrecision left, TPrecision right ){ static TPrecision GetHigherPrecision(TPrecision left, TPrecision right)
{
return left > right ? left : right; return left > right ? left : right;
} }
const char* getOperatorString(TOperator op) { const char* getOperatorString(TOperator op)
{
switch (op) { switch (op) {
case EOpInitialize: return "="; case EOpInitialize: return "=";
case EOpAssign: return "="; case EOpAssign: return "=";
...@@ -742,6 +744,63 @@ void TIntermediate::remove(TIntermNode* root) ...@@ -742,6 +744,63 @@ void TIntermediate::remove(TIntermNode* root)
// //
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
#define REPLACE_IF_IS(node, type, original, replacement) \
if (node == original) { \
node = static_cast<type *>(replacement); \
return true; \
}
bool TIntermLoop::replaceChildNode(
TIntermNode *original, TIntermNode *replacement)
{
REPLACE_IF_IS(init, TIntermNode, original, replacement);
REPLACE_IF_IS(cond, TIntermTyped, original, replacement);
REPLACE_IF_IS(expr, TIntermTyped, original, replacement);
REPLACE_IF_IS(body, TIntermNode, original, replacement);
return false;
}
bool TIntermBranch::replaceChildNode(
TIntermNode *original, TIntermNode *replacement)
{
REPLACE_IF_IS(expression, TIntermTyped, original, replacement);
return false;
}
bool TIntermBinary::replaceChildNode(
TIntermNode *original, TIntermNode *replacement)
{
REPLACE_IF_IS(left, TIntermTyped, original, replacement);
REPLACE_IF_IS(right, TIntermTyped, original, replacement);
return false;
}
bool TIntermUnary::replaceChildNode(
TIntermNode *original, TIntermNode *replacement)
{
REPLACE_IF_IS(operand, TIntermTyped, original, replacement);
return false;
}
bool TIntermAggregate::replaceChildNode(
TIntermNode *original, TIntermNode *replacement)
{
for (size_t ii = 0; ii < sequence.size(); ++ii)
{
REPLACE_IF_IS(sequence[ii], TIntermNode, original, replacement);
}
return false;
}
bool TIntermSelection::replaceChildNode(
TIntermNode *original, TIntermNode *replacement)
{
REPLACE_IF_IS(condition, TIntermTyped, original, replacement);
REPLACE_IF_IS(trueBlock, TIntermNode, original, replacement);
REPLACE_IF_IS(falseBlock, TIntermNode, original, replacement);
return false;
}
// //
// Say whether or not an operation node changes the value of a variable. // Say whether or not an operation node changes the value of a variable.
// //
...@@ -796,6 +855,7 @@ bool TIntermOperator::isConstructor() const ...@@ -796,6 +855,7 @@ bool TIntermOperator::isConstructor() const
return false; return false;
} }
} }
// //
// Make sure the type of a unary operator is appropriate for its // Make sure the type of a unary operator is appropriate for its
// combination of operation and operand type. // combination of operation and operand type.
......
...@@ -435,7 +435,7 @@ bool TOutputGLSLBase::visitSelection(Visit visit, TIntermSelection* node) ...@@ -435,7 +435,7 @@ bool TOutputGLSLBase::visitSelection(Visit visit, TIntermSelection* node)
node->getCondition()->traverse(this); node->getCondition()->traverse(this);
out << ")\n"; out << ")\n";
incrementDepth(); incrementDepth(node);
visitCodeBlock(node->getTrueBlock()); visitCodeBlock(node->getTrueBlock());
if (node->getFalseBlock()) if (node->getFalseBlock())
...@@ -460,7 +460,7 @@ bool TOutputGLSLBase::visitAggregate(Visit visit, TIntermAggregate* node) ...@@ -460,7 +460,7 @@ bool TOutputGLSLBase::visitAggregate(Visit visit, TIntermAggregate* node)
// Scope the sequences except when at the global scope. // Scope the sequences except when at the global scope.
if (depth > 0) out << "{\n"; if (depth > 0) out << "{\n";
incrementDepth(); incrementDepth(node);
const TIntermSequence& sequence = node->getSequence(); const TIntermSequence& sequence = node->getSequence();
for (TIntermSequence::const_iterator iter = sequence.begin(); for (TIntermSequence::const_iterator iter = sequence.begin();
iter != sequence.end(); ++iter) iter != sequence.end(); ++iter)
...@@ -498,7 +498,7 @@ bool TOutputGLSLBase::visitAggregate(Visit visit, TIntermAggregate* node) ...@@ -498,7 +498,7 @@ bool TOutputGLSLBase::visitAggregate(Visit visit, TIntermAggregate* node)
writeVariableType(node->getType()); writeVariableType(node->getType());
out << " " << hashFunctionName(node->getName()); out << " " << hashFunctionName(node->getName());
incrementDepth(); incrementDepth(node);
// Function definition node contains one or two children nodes // Function definition node contains one or two children nodes
// representing function parameters and function body. The latter // representing function parameters and function body. The latter
// is not present in case of empty function bodies. // is not present in case of empty function bodies.
...@@ -638,7 +638,7 @@ bool TOutputGLSLBase::visitLoop(Visit visit, TIntermLoop* node) ...@@ -638,7 +638,7 @@ bool TOutputGLSLBase::visitLoop(Visit visit, TIntermLoop* node)
{ {
TInfoSinkBase& out = objSink(); TInfoSinkBase& out = objSink();
incrementDepth(); incrementDepth(node);
// Loop header. // Loop header.
TLoopType loopType = node->getType(); TLoopType loopType = node->getType();
if (loopType == ELoopFor) // for loop if (loopType == ELoopFor) // for loop
......
//
// Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
#include "compiler/UnfoldShortCircuitAST.h"
namespace
{
// "x || y" is equivalent to "x ? true : y".
TIntermSelection *UnfoldOR(TIntermTyped *x, TIntermTyped *y)
{
const TType boolType(EbtBool, EbpUndefined);
ConstantUnion *u = new ConstantUnion;
u->setBConst(true);
TIntermConstantUnion *trueNode = new TIntermConstantUnion(
u, TType(EbtBool, EbpUndefined, EvqConst, 1));
return new TIntermSelection(x, trueNode, y, boolType);
}
// "x && y" is equivalent to "x ? y : false".
TIntermSelection *UnfoldAND(TIntermTyped *x, TIntermTyped *y)
{
const TType boolType(EbtBool, EbpUndefined);
ConstantUnion *u = new ConstantUnion;
u->setBConst(false);
TIntermConstantUnion *falseNode = new TIntermConstantUnion(
u, TType(EbtBool, EbpUndefined, EvqConst, 1));
return new TIntermSelection(x, y, falseNode, boolType);
}
} // namespace anonymous
bool UnfoldShortCircuitAST::visitBinary(Visit visit, TIntermBinary *node)
{
TIntermSelection *replacement = NULL;
switch (node->getOp())
{
case EOpLogicalOr:
replacement = UnfoldOR(node->getLeft(), node->getRight());
break;
case EOpLogicalAnd:
replacement = UnfoldAND(node->getLeft(), node->getRight());
break;
default:
break;
}
if (replacement)
{
replacements.push_back(
NodeUpdateEntry(getParentNode(), node, replacement));
}
return true;
}
void UnfoldShortCircuitAST::updateTree()
{
for (size_t ii = 0; ii < replacements.size(); ++ii)
{
const NodeUpdateEntry& entry = replacements[ii];
ASSERT(entry.parent);
bool replaced = entry.parent->replaceChildNode(
entry.original, entry.replacement);
ASSERT(replaced);
// In AST traversing, a parent is visited before its children.
// After we replace a node, if an immediate child is to
// be replaced, we need to make sure we don't update the replaced
// node; instead, we update the replacement node.
for (size_t jj = ii + 1; jj < replacements.size(); ++jj)
{
NodeUpdateEntry& entry2 = replacements[jj];
if (entry2.parent == entry.original)
entry2.parent = entry.replacement;
}
}
}
//
// Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// UnfoldShortCircuitAST is an AST traverser to replace short-circuiting
// operations with ternary operations.
//
#ifndef COMPILER_UNFOLD_SHORT_CIRCUIT_AST_H_
#define COMPILER_UNFOLD_SHORT_CIRCUIT_AST_H_
#include "common/angleutils.h"
#include "compiler/intermediate.h"
// This traverser identifies all the short circuit binary nodes that need to
// be replaced, and creates the corresponding replacement nodes. However,
// the actual replacements happen after the traverse through updateTree().
class UnfoldShortCircuitAST : public TIntermTraverser
{
public:
UnfoldShortCircuitAST() { }
virtual bool visitBinary(Visit visit, TIntermBinary *);
void updateTree();
private:
struct NodeUpdateEntry
{
NodeUpdateEntry(TIntermNode *_parent,
TIntermNode *_original,
TIntermNode *_replacement)
: parent(_parent),
original(_original),
replacement(_replacement) {}
TIntermNode *parent;
TIntermNode *original;
TIntermNode *replacement;
};
// During traversing, save all the replacements that need to happen;
// then replace them by calling updateNodes().
std::vector<NodeUpdateEntry> replacements;
DISALLOW_COPY_AND_ASSIGN(UnfoldShortCircuitAST);
};
#endif // COMPILER_UNFOLD_SHORT_CIRCUIT_AST_H_
...@@ -227,6 +227,11 @@ public: ...@@ -227,6 +227,11 @@ public:
virtual TIntermSymbol* getAsSymbolNode() { return 0; } virtual TIntermSymbol* getAsSymbolNode() { return 0; }
virtual TIntermLoop* getAsLoopNode() { return 0; } virtual TIntermLoop* getAsLoopNode() { return 0; }
// Replace a child node. Return true if |original| is a child
// node and it is replaced; otherwise, return false.
virtual bool replaceChildNode(
TIntermNode *original, TIntermNode *replacement) = 0;
protected: protected:
TSourceLoc line; TSourceLoc line;
}; };
...@@ -295,6 +300,8 @@ public: ...@@ -295,6 +300,8 @@ public:
virtual TIntermLoop* getAsLoopNode() { return this; } virtual TIntermLoop* getAsLoopNode() { return this; }
virtual void traverse(TIntermTraverser*); virtual void traverse(TIntermTraverser*);
virtual bool replaceChildNode(
TIntermNode *original, TIntermNode *replacement);
TLoopType getType() const { return type; } TLoopType getType() const { return type; }
TIntermNode* getInit() { return init; } TIntermNode* getInit() { return init; }
...@@ -325,6 +332,8 @@ public: ...@@ -325,6 +332,8 @@ public:
expression(e) { } expression(e) { }
virtual void traverse(TIntermTraverser*); virtual void traverse(TIntermTraverser*);
virtual bool replaceChildNode(
TIntermNode *original, TIntermNode *replacement);
TOperator getFlowOp() { return flowOp; } TOperator getFlowOp() { return flowOp; }
TIntermTyped* getExpression() { return expression; } TIntermTyped* getExpression() { return expression; }
...@@ -355,6 +364,7 @@ public: ...@@ -355,6 +364,7 @@ public:
virtual void traverse(TIntermTraverser*); virtual void traverse(TIntermTraverser*);
virtual TIntermSymbol* getAsSymbolNode() { return this; } virtual TIntermSymbol* getAsSymbolNode() { return this; }
virtual bool replaceChildNode(TIntermNode *, TIntermNode *) { return false; }
protected: protected:
int id; int id;
...@@ -374,6 +384,7 @@ public: ...@@ -374,6 +384,7 @@ public:
virtual TIntermConstantUnion* getAsConstantUnion() { return this; } virtual TIntermConstantUnion* getAsConstantUnion() { return this; }
virtual void traverse(TIntermTraverser*); virtual void traverse(TIntermTraverser*);
virtual bool replaceChildNode(TIntermNode *, TIntermNode *) { return false; }
TIntermTyped* fold(TOperator, TIntermTyped*, TInfoSink&); TIntermTyped* fold(TOperator, TIntermTyped*, TInfoSink&);
...@@ -407,6 +418,8 @@ public: ...@@ -407,6 +418,8 @@ public:
virtual TIntermBinary* getAsBinaryNode() { return this; } virtual TIntermBinary* getAsBinaryNode() { return this; }
virtual void traverse(TIntermTraverser*); virtual void traverse(TIntermTraverser*);
virtual bool replaceChildNode(
TIntermNode *original, TIntermNode *replacement);
void setLeft(TIntermTyped* n) { left = n; } void setLeft(TIntermTyped* n) { left = n; }
void setRight(TIntermTyped* n) { right = n; } void setRight(TIntermTyped* n) { right = n; }
...@@ -435,6 +448,8 @@ public: ...@@ -435,6 +448,8 @@ public:
virtual void traverse(TIntermTraverser*); virtual void traverse(TIntermTraverser*);
virtual TIntermUnary* getAsUnaryNode() { return this; } virtual TIntermUnary* getAsUnaryNode() { return this; }
virtual bool replaceChildNode(
TIntermNode *original, TIntermNode *replacement);
void setOperand(TIntermTyped* o) { operand = o; } void setOperand(TIntermTyped* o) { operand = o; }
TIntermTyped* getOperand() { return operand; } TIntermTyped* getOperand() { return operand; }
...@@ -465,6 +480,8 @@ public: ...@@ -465,6 +480,8 @@ public:
virtual TIntermAggregate* getAsAggregate() { return this; } virtual TIntermAggregate* getAsAggregate() { return this; }
virtual void traverse(TIntermTraverser*); virtual void traverse(TIntermTraverser*);
virtual bool replaceChildNode(
TIntermNode *original, TIntermNode *replacement);
TIntermSequence& getSequence() { return sequence; } TIntermSequence& getSequence() { return sequence; }
...@@ -508,6 +525,8 @@ public: ...@@ -508,6 +525,8 @@ public:
TIntermTyped(type), condition(cond), trueBlock(trueB), falseBlock(falseB) {} TIntermTyped(type), condition(cond), trueBlock(trueB), falseBlock(falseB) {}
virtual void traverse(TIntermTraverser*); virtual void traverse(TIntermTraverser*);
virtual bool replaceChildNode(
TIntermNode *original, TIntermNode *replacement);
bool usesTernaryOperator() const { return getBasicType() != EbtVoid; } bool usesTernaryOperator() const { return getBasicType() != EbtVoid; }
TIntermNode* getCondition() const { return condition; } TIntermNode* getCondition() const { return condition; }
...@@ -547,7 +566,7 @@ public: ...@@ -547,7 +566,7 @@ public:
rightToLeft(rightToLeft), rightToLeft(rightToLeft),
depth(0), depth(0),
maxDepth(0) {} maxDepth(0) {}
virtual ~TIntermTraverser() {}; virtual ~TIntermTraverser() {}
virtual void visitSymbol(TIntermSymbol*) {} virtual void visitSymbol(TIntermSymbol*) {}
virtual void visitConstantUnion(TIntermConstantUnion*) {} virtual void visitConstantUnion(TIntermConstantUnion*) {}
...@@ -559,8 +578,24 @@ public: ...@@ -559,8 +578,24 @@ public:
virtual bool visitBranch(Visit visit, TIntermBranch*) {return true;} virtual bool visitBranch(Visit visit, TIntermBranch*) {return true;}
int getMaxDepth() const {return maxDepth;} int getMaxDepth() const {return maxDepth;}
void incrementDepth() {depth++; maxDepth = std::max(maxDepth, depth); }
void decrementDepth() {depth--;} void incrementDepth(TIntermNode *current)
{
depth++;
maxDepth = std::max(maxDepth, depth);
path.push_back(current);
}
void decrementDepth()
{
depth--;
path.pop_back();
}
TIntermNode *getParentNode()
{
return path.size() == 0 ? NULL : path.back();
}
// Return the original name if hash function pointer is NULL; // Return the original name if hash function pointer is NULL;
// otherwise return the hashed name. // otherwise return the hashed name.
...@@ -574,6 +609,9 @@ public: ...@@ -574,6 +609,9 @@ public:
protected: protected:
int depth; int depth;
int maxDepth; int maxDepth;
// All the nodes from root to the current node's parent during traversing.
TVector<TIntermNode *> path;
}; };
#endif // __INTERMEDIATE_H #endif // __INTERMEDIATE_H
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment