Commit d0bad2c7 by Olli Etuaho Committed by Commit Bot

Split ternary node class from TIntermSelection

Ternary operator nodes are typed parts of expressions, they always have two children and the children are also guaranteed to be TIntermTyped. "If" selection nodes can't be a part of an expression, they can have either one or two children and the children are code blocks. Due to all of these differences it makes sense to store these using two different AST node classes. BUG=angleproject:1490 TEST=angle_unittests Change-Id: I913ab1d806e3cdb5c21106f078cc9c0b6c72ac54 Reviewed-on: https://chromium-review.googlesource.com/384512 Commit-Queue: Olli Etuaho <oetuaho@nvidia.com> Reviewed-by: 's avatarCorentin Wallez <cwallez@chromium.org>
parent 95738972
...@@ -292,6 +292,14 @@ void TIntermAggregate::setBuiltInFunctionPrecision() ...@@ -292,6 +292,14 @@ void TIntermAggregate::setBuiltInFunctionPrecision()
mType.setPrecision(precision); mType.setPrecision(precision);
} }
bool TIntermTernary::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
REPLACE_IF_IS(mCondition, TIntermTyped, original, replacement);
REPLACE_IF_IS(mTrueExpression, TIntermTyped, original, replacement);
REPLACE_IF_IS(mFalseExpression, TIntermTyped, original, replacement);
return false;
}
bool TIntermSelection::replaceChildNode( bool TIntermSelection::replaceChildNode(
TIntermNode *original, TIntermNode *replacement) TIntermNode *original, TIntermNode *replacement)
{ {
...@@ -456,20 +464,15 @@ TIntermUnary::TIntermUnary(const TIntermUnary &node) ...@@ -456,20 +464,15 @@ TIntermUnary::TIntermUnary(const TIntermUnary &node)
mOperand = operandCopy; mOperand = operandCopy;
} }
TIntermSelection::TIntermSelection(const TIntermSelection &node) : TIntermTyped(node) TIntermTernary::TIntermTernary(const TIntermTernary &node) : TIntermTyped(node)
{ {
// Only supported for ternary nodes, not if statements.
TIntermTyped *trueTyped = node.mTrueBlock->getAsTyped();
TIntermTyped *falseTyped = node.mFalseBlock->getAsTyped();
ASSERT(trueTyped != nullptr);
ASSERT(falseTyped != nullptr);
TIntermTyped *conditionCopy = node.mCondition->deepCopy(); TIntermTyped *conditionCopy = node.mCondition->deepCopy();
TIntermTyped *trueCopy = trueTyped->deepCopy(); TIntermTyped *trueCopy = node.mTrueExpression->deepCopy();
TIntermTyped *falseCopy = falseTyped->deepCopy(); TIntermTyped *falseCopy = node.mFalseExpression->deepCopy();
ASSERT(conditionCopy != nullptr && trueCopy != nullptr && falseCopy != nullptr); ASSERT(conditionCopy != nullptr && trueCopy != nullptr && falseCopy != nullptr);
mCondition = conditionCopy; mCondition = conditionCopy;
mTrueBlock = trueCopy; mTrueExpression = trueCopy;
mFalseBlock = falseCopy; mFalseExpression = falseCopy;
} }
bool TIntermOperator::isAssignment() const bool TIntermOperator::isAssignment() const
...@@ -692,6 +695,31 @@ TIntermBinary::TIntermBinary(TOperator op, TIntermTyped *left, TIntermTyped *rig ...@@ -692,6 +695,31 @@ TIntermBinary::TIntermBinary(TOperator op, TIntermTyped *left, TIntermTyped *rig
promote(); promote();
} }
TIntermTernary::TIntermTernary(TIntermTyped *cond,
TIntermTyped *trueExpression,
TIntermTyped *falseExpression)
: TIntermTyped(trueExpression->getType()),
mCondition(cond),
mTrueExpression(trueExpression),
mFalseExpression(falseExpression)
{
getTypePointer()->setQualifier(
TIntermTernary::DetermineQualifier(cond, trueExpression, falseExpression));
}
// static
TQualifier TIntermTernary::DetermineQualifier(TIntermTyped *cond,
TIntermTyped *trueExpression,
TIntermTyped *falseExpression)
{
if (cond->getQualifier() == EvqConst && trueExpression->getQualifier() == EvqConst &&
falseExpression->getQualifier() == EvqConst)
{
return EvqConst;
}
return EvqTemporary;
}
// //
// Establishes the type of the resultant operation, as well as // Establishes the type of the resultant operation, as well as
// makes the operator the correct one for the operands. // makes the operator the correct one for the operands.
......
...@@ -34,6 +34,7 @@ class TIntermAggregate; ...@@ -34,6 +34,7 @@ class TIntermAggregate;
class TIntermBinary; class TIntermBinary;
class TIntermUnary; class TIntermUnary;
class TIntermConstantUnion; class TIntermConstantUnion;
class TIntermTernary;
class TIntermSelection; class TIntermSelection;
class TIntermSwitch; class TIntermSwitch;
class TIntermCase; class TIntermCase;
...@@ -93,6 +94,7 @@ class TIntermNode : angle::NonCopyable ...@@ -93,6 +94,7 @@ class TIntermNode : angle::NonCopyable
virtual TIntermAggregate *getAsAggregate() { return 0; } virtual TIntermAggregate *getAsAggregate() { return 0; }
virtual TIntermBinary *getAsBinaryNode() { return 0; } virtual TIntermBinary *getAsBinaryNode() { return 0; }
virtual TIntermUnary *getAsUnaryNode() { return 0; } virtual TIntermUnary *getAsUnaryNode() { return 0; }
virtual TIntermTernary *getAsTernaryNode() { return nullptr; }
virtual TIntermSelection *getAsSelectionNode() { return 0; } virtual TIntermSelection *getAsSelectionNode() { return 0; }
virtual TIntermSwitch *getAsSwitchNode() { return 0; } virtual TIntermSwitch *getAsSwitchNode() { return 0; }
virtual TIntermCase *getAsCaseNode() { return 0; } virtual TIntermCase *getAsCaseNode() { return 0; }
...@@ -567,35 +569,53 @@ class TIntermAggregate : public TIntermOperator ...@@ -567,35 +569,53 @@ class TIntermAggregate : public TIntermOperator
TIntermAggregate(const TIntermAggregate &node); // note: not deleted, just private! TIntermAggregate(const TIntermAggregate &node); // note: not deleted, just private!
}; };
// // For ternary operators like a ? b : c.
class TIntermTernary : public TIntermTyped
{
public:
TIntermTernary(TIntermTyped *cond, TIntermTyped *trueExpression, TIntermTyped *falseExpression);
void traverse(TIntermTraverser *it) override;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
TIntermTyped *getCondition() const { return mCondition; }
TIntermTyped *getTrueExpression() const { return mTrueExpression; }
TIntermTyped *getFalseExpression() const { return mFalseExpression; }
TIntermTernary *getAsTernaryNode() override { return this; }
TIntermTyped *deepCopy() const override { return new TIntermTernary(*this); }
bool hasSideEffects() const override
{
return mCondition->hasSideEffects() || mTrueExpression->hasSideEffects() ||
mFalseExpression->hasSideEffects();
}
static TQualifier DetermineQualifier(TIntermTyped *cond,
TIntermTyped *trueExpression,
TIntermTyped *falseExpression);
private:
TIntermTernary(const TIntermTernary &node); // Note: not deleted, just private!
TIntermTyped *mCondition;
TIntermTyped *mTrueExpression;
TIntermTyped *mFalseExpression;
};
// For if tests. // For if tests.
// class TIntermSelection : public TIntermNode
class TIntermSelection : public TIntermTyped
{ {
public: public:
TIntermSelection(TIntermTyped *cond, TIntermNode *trueB, TIntermNode *falseB) TIntermSelection(TIntermTyped *cond, TIntermNode *trueB, TIntermNode *falseB)
: TIntermTyped(TType(EbtVoid, EbpUndefined)), : TIntermNode(), mCondition(cond), mTrueBlock(trueB), mFalseBlock(falseB)
mCondition(cond), {
mTrueBlock(trueB), }
mFalseBlock(falseB) {}
TIntermSelection(TIntermTyped *cond, TIntermNode *trueB, TIntermNode *falseB,
const TType &type)
: TIntermTyped(type),
mCondition(cond),
mTrueBlock(trueB),
mFalseBlock(falseB) {}
// Note: only supported for ternary operator nodes.
TIntermTyped *deepCopy() const override { return new TIntermSelection(*this); }
void traverse(TIntermTraverser *it) override; void traverse(TIntermTraverser *it) override;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override; bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
// Conservatively assume selections have side-effects TIntermTyped *getCondition() const { return mCondition; }
bool hasSideEffects() const override { return true; }
bool usesTernaryOperator() const { return getBasicType() != EbtVoid; }
TIntermNode *getCondition() const { return mCondition; }
TIntermNode *getTrueBlock() const { return mTrueBlock; } TIntermNode *getTrueBlock() const { return mTrueBlock; }
TIntermNode *getFalseBlock() const { return mFalseBlock; } TIntermNode *getFalseBlock() const { return mFalseBlock; }
TIntermSelection *getAsSelectionNode() override { return this; } TIntermSelection *getAsSelectionNode() override { return this; }
...@@ -604,9 +624,6 @@ class TIntermSelection : public TIntermTyped ...@@ -604,9 +624,6 @@ class TIntermSelection : public TIntermTyped
TIntermTyped *mCondition; TIntermTyped *mCondition;
TIntermNode *mTrueBlock; TIntermNode *mTrueBlock;
TIntermNode *mFalseBlock; TIntermNode *mFalseBlock;
private:
TIntermSelection(const TIntermSelection &node); // Note: not deleted, just private!
}; };
// //
...@@ -692,6 +709,7 @@ class TIntermTraverser : angle::NonCopyable ...@@ -692,6 +709,7 @@ class TIntermTraverser : angle::NonCopyable
virtual void visitConstantUnion(TIntermConstantUnion *node) {} virtual void visitConstantUnion(TIntermConstantUnion *node) {}
virtual bool visitBinary(Visit visit, TIntermBinary *node) { return true; } virtual bool visitBinary(Visit visit, TIntermBinary *node) { return true; }
virtual bool visitUnary(Visit visit, TIntermUnary *node) { return true; } virtual bool visitUnary(Visit visit, TIntermUnary *node) { return true; }
virtual bool visitTernary(Visit visit, TIntermTernary *node) { return true; }
virtual bool visitSelection(Visit visit, TIntermSelection *node) { return true; } virtual bool visitSelection(Visit visit, TIntermSelection *node) { return true; }
virtual bool visitSwitch(Visit visit, TIntermSwitch *node) { return true; } virtual bool visitSwitch(Visit visit, TIntermSwitch *node) { return true; }
virtual bool visitCase(Visit visit, TIntermCase *node) { return true; } virtual bool visitCase(Visit visit, TIntermCase *node) { return true; }
...@@ -707,6 +725,7 @@ class TIntermTraverser : angle::NonCopyable ...@@ -707,6 +725,7 @@ class TIntermTraverser : angle::NonCopyable
virtual void traverseConstantUnion(TIntermConstantUnion *node); virtual void traverseConstantUnion(TIntermConstantUnion *node);
virtual void traverseBinary(TIntermBinary *node); virtual void traverseBinary(TIntermBinary *node);
virtual void traverseUnary(TIntermUnary *node); virtual void traverseUnary(TIntermUnary *node);
virtual void traverseTernary(TIntermTernary *node);
virtual void traverseSelection(TIntermSelection *node); virtual void traverseSelection(TIntermSelection *node);
virtual void traverseSwitch(TIntermSwitch *node); virtual void traverseSwitch(TIntermSwitch *node);
virtual void traverseCase(TIntermCase *node); virtual void traverseCase(TIntermCase *node);
...@@ -994,6 +1013,7 @@ class TMaxDepthTraverser : public TIntermTraverser ...@@ -994,6 +1013,7 @@ class TMaxDepthTraverser : public TIntermTraverser
bool visitBinary(Visit, TIntermBinary *) override { return depthCheck(); } bool visitBinary(Visit, TIntermBinary *) override { return depthCheck(); }
bool visitUnary(Visit, TIntermUnary *) override { return depthCheck(); } bool visitUnary(Visit, TIntermUnary *) override { return depthCheck(); }
bool visitTernary(Visit, TIntermTernary *) override { return depthCheck(); }
bool visitSelection(Visit, TIntermSelection *) override { return depthCheck(); } bool visitSelection(Visit, TIntermSelection *) override { return depthCheck(); }
bool visitAggregate(Visit, TIntermAggregate *) override { return depthCheck(); } bool visitAggregate(Visit, TIntermAggregate *) override { return depthCheck(); }
bool visitLoop(Visit, TIntermLoop *) override { return depthCheck(); } bool visitLoop(Visit, TIntermLoop *) override { return depthCheck(); }
......
...@@ -105,14 +105,11 @@ bool IntermNodePatternMatcher::match(TIntermAggregate *node, TIntermNode *parent ...@@ -105,14 +105,11 @@ bool IntermNodePatternMatcher::match(TIntermAggregate *node, TIntermNode *parent
return false; return false;
} }
bool IntermNodePatternMatcher::match(TIntermSelection *node) bool IntermNodePatternMatcher::match(TIntermTernary *node)
{ {
if ((mMask & kUnfoldedShortCircuitExpression) != 0) if ((mMask & kUnfoldedShortCircuitExpression) != 0)
{ {
if (node->usesTernaryOperator()) return true;
{
return true;
}
} }
return false; return false;
} }
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
class TIntermAggregate; class TIntermAggregate;
class TIntermBinary; class TIntermBinary;
class TIntermNode; class TIntermNode;
class TIntermSelection; class TIntermTernary;
class IntermNodePatternMatcher class IntermNodePatternMatcher
{ {
...@@ -42,7 +42,7 @@ class IntermNodePatternMatcher ...@@ -42,7 +42,7 @@ class IntermNodePatternMatcher
bool match(TIntermBinary *node, TIntermNode *parentNode, bool isLValueRequiredHere); bool match(TIntermBinary *node, TIntermNode *parentNode, bool isLValueRequiredHere);
bool match(TIntermAggregate *node, TIntermNode *parentNode); bool match(TIntermAggregate *node, TIntermNode *parentNode);
bool match(TIntermSelection *node); bool match(TIntermTernary *node);
private: private:
const unsigned int mMask; const unsigned int mMask;
......
...@@ -33,6 +33,11 @@ void TIntermUnary::traverse(TIntermTraverser *it) ...@@ -33,6 +33,11 @@ void TIntermUnary::traverse(TIntermTraverser *it)
it->traverseUnary(this); it->traverseUnary(this);
} }
void TIntermTernary::traverse(TIntermTraverser *it)
{
it->traverseTernary(this);
}
void TIntermSelection::traverse(TIntermTraverser *it) void TIntermSelection::traverse(TIntermTraverser *it)
{ {
it->traverseSelection(this); it->traverseSelection(this);
...@@ -567,6 +572,31 @@ void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node) ...@@ -567,6 +572,31 @@ void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
} }
// //
// Traverse a ternary node. Same comments in binary node apply here.
//
void TIntermTraverser::traverseTernary(TIntermTernary *node)
{
bool visit = true;
if (preVisit)
visit = visitTernary(PreVisit, node);
if (visit)
{
incrementDepth(node);
node->getCondition()->traverse(this);
if (node->getTrueExpression())
node->getTrueExpression()->traverse(this);
if (node->getFalseExpression())
node->getFalseExpression()->traverse(this);
decrementDepth();
}
if (visit && postVisit)
visitTernary(PostVisit, node);
}
//
// Traverse a selection node. Same comments in binary node apply here. // Traverse a selection node. Same comments in binary node apply here.
// //
void TIntermTraverser::traverseSelection(TIntermSelection *node) void TIntermTraverser::traverseSelection(TIntermSelection *node)
......
...@@ -236,43 +236,37 @@ TIntermTyped *TIntermediate::addComma(TIntermTyped *left, ...@@ -236,43 +236,37 @@ TIntermTyped *TIntermediate::addComma(TIntermTyped *left,
return commaNode; return commaNode;
} }
//
// For "?:" test nodes. There are three children; a condition, // For "?:" test nodes. There are three children; a condition,
// a true path, and a false path. The two paths are specified // a true path, and a false path. The two paths are specified
// as separate parameters. // as separate parameters.
// //
// Returns the selection node created, or one of trueBlock and falseBlock if the expression could be folded. // Returns the ternary node created, or one of trueExpression and falseExpression if the expression
// // could be folded.
TIntermTyped *TIntermediate::addSelection(TIntermTyped *cond, TIntermTyped *trueBlock, TIntermTyped *falseBlock, TIntermTyped *TIntermediate::AddTernarySelection(TIntermTyped *cond,
const TSourceLoc &line) TIntermTyped *trueExpression,
TIntermTyped *falseExpression,
const TSourceLoc &line)
{ {
TQualifier resultQualifier = EvqTemporary;
if (cond->getQualifier() == EvqConst && trueBlock->getQualifier() == EvqConst &&
falseBlock->getQualifier() == EvqConst)
{
resultQualifier = EvqConst;
}
// Note that the node resulting from here can be a constant union without being qualified as // Note that the node resulting from here can be a constant union without being qualified as
// constant. // constant.
if (cond->getAsConstantUnion()) if (cond->getAsConstantUnion())
{ {
TQualifier resultQualifier =
TIntermTernary::DetermineQualifier(cond, trueExpression, falseExpression);
if (cond->getAsConstantUnion()->getBConst(0)) if (cond->getAsConstantUnion()->getBConst(0))
{ {
trueBlock->getTypePointer()->setQualifier(resultQualifier); trueExpression->getTypePointer()->setQualifier(resultQualifier);
return trueBlock; return trueExpression;
} }
else else
{ {
falseBlock->getTypePointer()->setQualifier(resultQualifier); falseExpression->getTypePointer()->setQualifier(resultQualifier);
return falseBlock; return falseExpression;
} }
} }
// // Make a ternary node.
// Make a selection node. TIntermTernary *node = new TIntermTernary(cond, trueExpression, falseExpression);
//
TIntermSelection *node = new TIntermSelection(cond, trueBlock, falseBlock, trueBlock->getType());
node->getTypePointer()->setQualifier(resultQualifier);
node->setLine(line); node->setLine(line);
return node; return node;
......
...@@ -38,9 +38,11 @@ class TIntermediate ...@@ -38,9 +38,11 @@ class TIntermediate
TIntermAggregate *makeAggregate(TIntermNode *node, const TSourceLoc &); TIntermAggregate *makeAggregate(TIntermNode *node, const TSourceLoc &);
TIntermAggregate *ensureSequence(TIntermNode *node); TIntermAggregate *ensureSequence(TIntermNode *node);
TIntermAggregate *setAggregateOperator(TIntermNode *, TOperator, const TSourceLoc &); TIntermAggregate *setAggregateOperator(TIntermNode *, TOperator, const TSourceLoc &);
TIntermNode *addSelection(TIntermTyped *cond, TIntermNodePair code, const TSourceLoc &); TIntermNode *addSelection(TIntermTyped *cond, TIntermNodePair code, const TSourceLoc &line);
TIntermTyped *addSelection(TIntermTyped *cond, TIntermTyped *trueBlock, TIntermTyped *falseBlock, static TIntermTyped *AddTernarySelection(TIntermTyped *cond,
const TSourceLoc &line); TIntermTyped *trueExpression,
TIntermTyped *falseExpression,
const TSourceLoc &line);
TIntermSwitch *addSwitch( TIntermSwitch *addSwitch(
TIntermTyped *init, TIntermAggregate *statementList, const TSourceLoc &line); TIntermTyped *init, TIntermAggregate *statementList, const TSourceLoc &line);
TIntermCase *addCase( TIntermCase *addCase(
......
...@@ -27,11 +27,9 @@ bool isSingleStatement(TIntermNode *node) ...@@ -27,11 +27,9 @@ bool isSingleStatement(TIntermNode *node)
return (aggregate->getOp() != EOpFunction) && return (aggregate->getOp() != EOpFunction) &&
(aggregate->getOp() != EOpSequence); (aggregate->getOp() != EOpSequence);
} }
else if (const TIntermSelection *selection = node->getAsSelectionNode()) else if (node->getAsSelectionNode())
{ {
// Ternary operators are usually part of an assignment operator. return false;
// This handles those rare cases in which they are all by themselves.
return selection->usesTernaryOperator();
} }
else if (node->getAsLoopNode()) else if (node->getAsLoopNode())
{ {
...@@ -711,40 +709,40 @@ bool TOutputGLSLBase::visitUnary(Visit visit, TIntermUnary *node) ...@@ -711,40 +709,40 @@ bool TOutputGLSLBase::visitUnary(Visit visit, TIntermUnary *node)
return true; return true;
} }
bool TOutputGLSLBase::visitTernary(Visit visit, TIntermTernary *node)
{
TInfoSinkBase &out = objSink();
// Notice two brackets at the beginning and end. The outer ones
// encapsulate the whole ternary expression. This preserves the
// order of precedence when ternary expressions are used in a
// compound expression, i.e., c = 2 * (a < b ? 1 : 2).
out << "((";
node->getCondition()->traverse(this);
out << ") ? (";
node->getTrueExpression()->traverse(this);
out << ") : (";
node->getFalseExpression()->traverse(this);
out << "))";
return false;
}
bool TOutputGLSLBase::visitSelection(Visit visit, TIntermSelection *node) bool TOutputGLSLBase::visitSelection(Visit visit, TIntermSelection *node)
{ {
TInfoSinkBase &out = objSink(); TInfoSinkBase &out = objSink();
if (node->usesTernaryOperator()) out << "if (";
{ node->getCondition()->traverse(this);
// Notice two brackets at the beginning and end. The outer ones out << ")\n";
// encapsulate the whole ternary expression. This preserves the
// order of precedence when ternary expressions are used in a
// compound expression, i.e., c = 2 * (a < b ? 1 : 2).
out << "((";
node->getCondition()->traverse(this);
out << ") ? (";
node->getTrueBlock()->traverse(this);
out << ") : (";
node->getFalseBlock()->traverse(this);
out << "))";
}
else
{
out << "if (";
node->getCondition()->traverse(this);
out << ")\n";
incrementDepth(node); incrementDepth(node);
visitCodeBlock(node->getTrueBlock()); visitCodeBlock(node->getTrueBlock());
if (node->getFalseBlock()) if (node->getFalseBlock())
{ {
out << "else\n"; out << "else\n";
visitCodeBlock(node->getFalseBlock()); visitCodeBlock(node->getFalseBlock());
}
decrementDepth();
} }
decrementDepth();
return false; return false;
} }
......
...@@ -44,6 +44,7 @@ class TOutputGLSLBase : public TIntermTraverser ...@@ -44,6 +44,7 @@ class TOutputGLSLBase : public TIntermTraverser
void visitConstantUnion(TIntermConstantUnion *node) override; void visitConstantUnion(TIntermConstantUnion *node) override;
bool visitBinary(Visit visit, TIntermBinary *node) override; bool visitBinary(Visit visit, TIntermBinary *node) override;
bool visitUnary(Visit visit, TIntermUnary *node) override; bool visitUnary(Visit visit, TIntermUnary *node) override;
bool visitTernary(Visit visit, TIntermTernary *node) override;
bool visitSelection(Visit visit, TIntermSelection *node) override; bool visitSelection(Visit visit, TIntermSelection *node) override;
bool visitSwitch(Visit visit, TIntermSwitch *node) override; bool visitSwitch(Visit visit, TIntermSwitch *node) override;
bool visitCase(Visit visit, TIntermCase *node) override; bool visitCase(Visit visit, TIntermCase *node) override;
......
...@@ -1456,9 +1456,8 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -1456,9 +1456,8 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
// case statements into non-empty case statements, disallowing fall-through from them. // case statements into non-empty case statements, disallowing fall-through from them.
// Also no need to output ; after selection (if) statements or sequences. This is done just // Also no need to output ; after selection (if) statements or sequences. This is done just
// for code clarity. // for code clarity.
TIntermSelection *asSelection = (*sit)->getAsSelectionNode(); if ((*sit)->getAsCaseNode() == nullptr && (*sit)->getAsSelectionNode() == nullptr &&
ASSERT(asSelection == nullptr || !asSelection->usesTernaryOperator()); !IsSequence(*sit))
if ((*sit)->getAsCaseNode() == nullptr && asSelection == nullptr && !IsSequence(*sit))
out << ";\n"; out << ";\n";
} }
...@@ -1982,11 +1981,18 @@ void OutputHLSL::writeSelection(TInfoSinkBase &out, TIntermSelection *node) ...@@ -1982,11 +1981,18 @@ void OutputHLSL::writeSelection(TInfoSinkBase &out, TIntermSelection *node)
} }
} }
bool OutputHLSL::visitTernary(Visit, TIntermTernary *)
{
// Ternary ops should have been already converted to something else in the AST. HLSL ternary
// operator doesn't short-circuit, so it's not the same as the GLSL ternary operator.
UNREACHABLE();
return false;
}
bool OutputHLSL::visitSelection(Visit visit, TIntermSelection *node) bool OutputHLSL::visitSelection(Visit visit, TIntermSelection *node)
{ {
TInfoSinkBase &out = getInfoSink(); TInfoSinkBase &out = getInfoSink();
ASSERT(!node->usesTernaryOperator());
ASSERT(mInsideFunction); ASSERT(mInsideFunction);
// D3D errors when there is a gradient operation in a loop in an unflattened if. // D3D errors when there is a gradient operation in a loop in an unflattened if.
......
...@@ -58,6 +58,7 @@ class OutputHLSL : public TIntermTraverser ...@@ -58,6 +58,7 @@ class OutputHLSL : public TIntermTraverser
void visitConstantUnion(TIntermConstantUnion*); void visitConstantUnion(TIntermConstantUnion*);
bool visitBinary(Visit visit, TIntermBinary*); bool visitBinary(Visit visit, TIntermBinary*);
bool visitUnary(Visit visit, TIntermUnary*); bool visitUnary(Visit visit, TIntermUnary*);
bool visitTernary(Visit visit, TIntermTernary *);
bool visitSelection(Visit visit, TIntermSelection*); bool visitSelection(Visit visit, TIntermSelection*);
bool visitSwitch(Visit visit, TIntermSwitch *); bool visitSwitch(Visit visit, TIntermSwitch *);
bool visitCase(Visit visit, TIntermCase *); bool visitCase(Visit visit, TIntermCase *);
......
...@@ -3903,34 +3903,35 @@ TIntermTyped *TParseContext::addFunctionCallOrMethod(TFunction *fnCall, ...@@ -3903,34 +3903,35 @@ TIntermTyped *TParseContext::addFunctionCallOrMethod(TFunction *fnCall,
} }
TIntermTyped *TParseContext::addTernarySelection(TIntermTyped *cond, TIntermTyped *TParseContext::addTernarySelection(TIntermTyped *cond,
TIntermTyped *trueBlock, TIntermTyped *trueExpression,
TIntermTyped *falseBlock, TIntermTyped *falseExpression,
const TSourceLoc &loc) const TSourceLoc &loc)
{ {
checkIsScalarBool(loc, cond); checkIsScalarBool(loc, cond);
if (trueBlock->getType() != falseBlock->getType()) if (trueExpression->getType() != falseExpression->getType())
{ {
binaryOpError(loc, ":", trueBlock->getCompleteString(), falseBlock->getCompleteString()); binaryOpError(loc, ":", trueExpression->getCompleteString(),
return falseBlock; falseExpression->getCompleteString());
return falseExpression;
} }
// ESSL1 sections 5.2 and 5.7: // ESSL1 sections 5.2 and 5.7:
// ESSL3 section 5.7: // ESSL3 section 5.7:
// Ternary operator is not among the operators allowed for structures/arrays. // Ternary operator is not among the operators allowed for structures/arrays.
if (trueBlock->isArray() || trueBlock->getBasicType() == EbtStruct) if (trueExpression->isArray() || trueExpression->getBasicType() == EbtStruct)
{ {
error(loc, "ternary operator is not allowed for structures or arrays", ":"); error(loc, "ternary operator is not allowed for structures or arrays", ":");
return falseBlock; return falseExpression;
} }
// WebGL2 section 5.26, the following results in an error: // WebGL2 section 5.26, the following results in an error:
// "Ternary operator applied to void, arrays, or structs containing arrays" // "Ternary operator applied to void, arrays, or structs containing arrays"
if (mShaderSpec == SH_WEBGL2_SPEC && trueBlock->getBasicType() == EbtVoid) if (mShaderSpec == SH_WEBGL2_SPEC && trueExpression->getBasicType() == EbtVoid)
{ {
error(loc, "ternary operator is not allowed for void", ":"); error(loc, "ternary operator is not allowed for void", ":");
return falseBlock; return falseExpression;
} }
return intermediate.addSelection(cond, trueBlock, falseBlock, loc); return TIntermediate::AddTernarySelection(cond, trueExpression, falseExpression, loc);
} }
// //
......
...@@ -360,8 +360,10 @@ class TParseContext : angle::NonCopyable ...@@ -360,8 +360,10 @@ class TParseContext : angle::NonCopyable
const TSourceLoc &loc, const TSourceLoc &loc,
bool *fatalError); bool *fatalError);
TIntermTyped *addTernarySelection( TIntermTyped *addTernarySelection(TIntermTyped *cond,
TIntermTyped *cond, TIntermTyped *trueBlock, TIntermTyped *falseBlock, const TSourceLoc &line); TIntermTyped *trueExpression,
TIntermTyped *falseExpression,
const TSourceLoc &line);
// TODO(jmadill): make these private // TODO(jmadill): make these private
TIntermediate intermediate; // to build a parse tree TIntermediate intermediate; // to build a parse tree
......
...@@ -62,6 +62,13 @@ bool RemoveSwitchFallThrough::visitUnary(Visit, TIntermUnary *node) ...@@ -62,6 +62,13 @@ bool RemoveSwitchFallThrough::visitUnary(Visit, TIntermUnary *node)
return false; return false;
} }
bool RemoveSwitchFallThrough::visitTernary(Visit, TIntermTernary *node)
{
mPreviousCase->getSequence()->push_back(node);
mLastStatementWasBreak = false;
return false;
}
bool RemoveSwitchFallThrough::visitSelection(Visit, TIntermSelection *node) bool RemoveSwitchFallThrough::visitSelection(Visit, TIntermSelection *node)
{ {
mPreviousCase->getSequence()->push_back(node); mPreviousCase->getSequence()->push_back(node);
......
...@@ -23,6 +23,7 @@ class RemoveSwitchFallThrough : public TIntermTraverser ...@@ -23,6 +23,7 @@ class RemoveSwitchFallThrough : public TIntermTraverser
void visitConstantUnion(TIntermConstantUnion *node) override; void visitConstantUnion(TIntermConstantUnion *node) override;
bool visitBinary(Visit, TIntermBinary *node) override; bool visitBinary(Visit, TIntermBinary *node) override;
bool visitUnary(Visit, TIntermUnary *node) override; bool visitUnary(Visit, TIntermUnary *node) override;
bool visitTernary(Visit visit, TIntermTernary *node) override;
bool visitSelection(Visit visit, TIntermSelection *node) override; bool visitSelection(Visit visit, TIntermSelection *node) override;
bool visitSwitch(Visit, TIntermSwitch *node) override; bool visitSwitch(Visit, TIntermSwitch *node) override;
bool visitCase(Visit, TIntermCase *node) override; bool visitCase(Visit, TIntermCase *node) override;
......
...@@ -36,7 +36,7 @@ class SimplifyLoopConditionsTraverser : public TLValueTrackingTraverser ...@@ -36,7 +36,7 @@ class SimplifyLoopConditionsTraverser : public TLValueTrackingTraverser
bool visitBinary(Visit visit, TIntermBinary *node) override; bool visitBinary(Visit visit, TIntermBinary *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *node) override; bool visitAggregate(Visit visit, TIntermAggregate *node) override;
bool visitSelection(Visit visit, TIntermSelection *node) override; bool visitTernary(Visit visit, TIntermTernary *node) override;
void nextIteration(); void nextIteration();
bool foundLoopToChange() const { return mFoundLoopToChange; } bool foundLoopToChange() const { return mFoundLoopToChange; }
...@@ -100,14 +100,14 @@ bool SimplifyLoopConditionsTraverser::visitAggregate(Visit visit, TIntermAggrega ...@@ -100,14 +100,14 @@ bool SimplifyLoopConditionsTraverser::visitAggregate(Visit visit, TIntermAggrega
return !mFoundLoopToChange; return !mFoundLoopToChange;
} }
bool SimplifyLoopConditionsTraverser::visitSelection(Visit visit, TIntermSelection *node) bool SimplifyLoopConditionsTraverser::visitTernary(Visit visit, TIntermTernary *node)
{ {
if (mFoundLoopToChange) if (mFoundLoopToChange)
return false; return false;
// Don't traverse ternary operators outside loop conditions. // Don't traverse ternary operators outside loop conditions.
if (!mInsideLoopConditionOrExpression) if (!mInsideLoopConditionOrExpression)
return !node->usesTernaryOperator(); return false;
mFoundLoopToChange = mConditionsToSimplify.match(node); mFoundLoopToChange = mConditionsToSimplify.match(node);
return !mFoundLoopToChange; return !mFoundLoopToChange;
......
...@@ -26,7 +26,7 @@ class SplitSequenceOperatorTraverser : public TLValueTrackingTraverser ...@@ -26,7 +26,7 @@ class SplitSequenceOperatorTraverser : public TLValueTrackingTraverser
bool visitBinary(Visit visit, TIntermBinary *node) override; bool visitBinary(Visit visit, TIntermBinary *node) override;
bool visitAggregate(Visit visit, TIntermAggregate *node) override; bool visitAggregate(Visit visit, TIntermAggregate *node) override;
bool visitSelection(Visit visit, TIntermSelection *node) override; bool visitTernary(Visit visit, TIntermTernary *node) override;
void nextIteration(); void nextIteration();
bool foundExpressionToSplit() const { return mFoundExpressionToSplit; } bool foundExpressionToSplit() const { return mFoundExpressionToSplit; }
...@@ -123,7 +123,7 @@ bool SplitSequenceOperatorTraverser::visitAggregate(Visit visit, TIntermAggregat ...@@ -123,7 +123,7 @@ bool SplitSequenceOperatorTraverser::visitAggregate(Visit visit, TIntermAggregat
return true; return true;
} }
bool SplitSequenceOperatorTraverser::visitSelection(Visit visit, TIntermSelection *node) bool SplitSequenceOperatorTraverser::visitTernary(Visit visit, TIntermTernary *node)
{ {
if (mFoundExpressionToSplit) if (mFoundExpressionToSplit)
return false; return false;
......
...@@ -10,32 +10,30 @@ namespace ...@@ -10,32 +10,30 @@ namespace
{ {
// "x || y" is equivalent to "x ? true : y". // "x || y" is equivalent to "x ? true : y".
TIntermSelection *UnfoldOR(TIntermTyped *x, TIntermTyped *y) TIntermTernary *UnfoldOR(TIntermTyped *x, TIntermTyped *y)
{ {
const TType boolType(EbtBool, EbpUndefined);
TConstantUnion *u = new TConstantUnion; TConstantUnion *u = new TConstantUnion;
u->setBConst(true); u->setBConst(true);
TIntermConstantUnion *trueNode = new TIntermConstantUnion( TIntermConstantUnion *trueNode = new TIntermConstantUnion(
u, TType(EbtBool, EbpUndefined, EvqConst, 1)); u, TType(EbtBool, EbpUndefined, EvqConst, 1));
return new TIntermSelection(x, trueNode, y, boolType); return new TIntermTernary(x, trueNode, y);
} }
// "x && y" is equivalent to "x ? y : false". // "x && y" is equivalent to "x ? y : false".
TIntermSelection *UnfoldAND(TIntermTyped *x, TIntermTyped *y) TIntermTernary *UnfoldAND(TIntermTyped *x, TIntermTyped *y)
{ {
const TType boolType(EbtBool, EbpUndefined);
TConstantUnion *u = new TConstantUnion; TConstantUnion *u = new TConstantUnion;
u->setBConst(false); u->setBConst(false);
TIntermConstantUnion *falseNode = new TIntermConstantUnion( TIntermConstantUnion *falseNode = new TIntermConstantUnion(
u, TType(EbtBool, EbpUndefined, EvqConst, 1)); u, TType(EbtBool, EbpUndefined, EvqConst, 1));
return new TIntermSelection(x, y, falseNode, boolType); return new TIntermTernary(x, y, falseNode);
} }
} // namespace anonymous } // namespace anonymous
bool UnfoldShortCircuitAST::visitBinary(Visit visit, TIntermBinary *node) bool UnfoldShortCircuitAST::visitBinary(Visit visit, TIntermBinary *node)
{ {
TIntermSelection *replacement = NULL; TIntermTernary *replacement = nullptr;
switch (node->getOp()) switch (node->getOp())
{ {
......
...@@ -23,7 +23,7 @@ class UnfoldShortCircuitTraverser : public TIntermTraverser ...@@ -23,7 +23,7 @@ class UnfoldShortCircuitTraverser : public TIntermTraverser
UnfoldShortCircuitTraverser(); UnfoldShortCircuitTraverser();
bool visitBinary(Visit visit, TIntermBinary *node) override; bool visitBinary(Visit visit, TIntermBinary *node) override;
bool visitSelection(Visit visit, TIntermSelection *node) override; bool visitTernary(Visit visit, TIntermTernary *node) override;
void nextIteration(); void nextIteration();
bool foundShortCircuit() const { return mFoundShortCircuit; } bool foundShortCircuit() const { return mFoundShortCircuit; }
...@@ -118,7 +118,7 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -118,7 +118,7 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node)
} }
} }
bool UnfoldShortCircuitTraverser::visitSelection(Visit visit, TIntermSelection *node) bool UnfoldShortCircuitTraverser::visitTernary(Visit visit, TIntermTernary *node)
{ {
if (mFoundShortCircuit) if (mFoundShortCircuit)
return false; return false;
...@@ -131,8 +131,6 @@ bool UnfoldShortCircuitTraverser::visitSelection(Visit visit, TIntermSelection * ...@@ -131,8 +131,6 @@ bool UnfoldShortCircuitTraverser::visitSelection(Visit visit, TIntermSelection *
mFoundShortCircuit = true; mFoundShortCircuit = true;
ASSERT(node->usesTernaryOperator());
// Unfold "b ? x : y" into "type s; if(b) s = x; else s = y;" // Unfold "b ? x : y" into "type s; if(b) s = x; else s = y;"
TIntermSequence insertions; TIntermSequence insertions;
...@@ -142,11 +140,11 @@ bool UnfoldShortCircuitTraverser::visitSelection(Visit visit, TIntermSelection * ...@@ -142,11 +140,11 @@ bool UnfoldShortCircuitTraverser::visitSelection(Visit visit, TIntermSelection *
insertions.push_back(tempDeclaration); insertions.push_back(tempDeclaration);
TIntermAggregate *trueBlock = new TIntermAggregate(EOpSequence); TIntermAggregate *trueBlock = new TIntermAggregate(EOpSequence);
TIntermBinary *trueAssignment = createTempAssignment(node->getTrueBlock()->getAsTyped()); TIntermBinary *trueAssignment = createTempAssignment(node->getTrueExpression());
trueBlock->getSequence()->push_back(trueAssignment); trueBlock->getSequence()->push_back(trueAssignment);
TIntermAggregate *falseBlock = new TIntermAggregate(EOpSequence); TIntermAggregate *falseBlock = new TIntermAggregate(EOpSequence);
TIntermBinary *falseAssignment = createTempAssignment(node->getFalseBlock()->getAsTyped()); TIntermBinary *falseAssignment = createTempAssignment(node->getFalseExpression());
falseBlock->getSequence()->push_back(falseAssignment); falseBlock->getSequence()->push_back(falseAssignment);
TIntermSelection *ifNode = TIntermSelection *ifNode =
......
...@@ -59,6 +59,14 @@ bool ValidateSwitch::visitUnary(Visit, TIntermUnary *) ...@@ -59,6 +59,14 @@ bool ValidateSwitch::visitUnary(Visit, TIntermUnary *)
{ {
if (!mFirstCaseFound) if (!mFirstCaseFound)
mStatementBeforeCase = true; mStatementBeforeCase = true;
mLastStatementWasCase = false;
return true;
}
bool ValidateSwitch::visitTernary(Visit, TIntermTernary *)
{
if (!mFirstCaseFound)
mStatementBeforeCase = true;
mLastStatementWasCase = false; mLastStatementWasCase = false;
return true; return true;
} }
......
...@@ -23,6 +23,7 @@ class ValidateSwitch : public TIntermTraverser ...@@ -23,6 +23,7 @@ class ValidateSwitch : public TIntermTraverser
void visitConstantUnion(TIntermConstantUnion *) override; void visitConstantUnion(TIntermConstantUnion *) override;
bool visitBinary(Visit, TIntermBinary *) override; bool visitBinary(Visit, TIntermBinary *) override;
bool visitUnary(Visit, TIntermUnary *) override; bool visitUnary(Visit, TIntermUnary *) override;
bool visitTernary(Visit, TIntermTernary *) override;
bool visitSelection(Visit visit, TIntermSelection *) override; bool visitSelection(Visit visit, TIntermSelection *) override;
bool visitSwitch(Visit, TIntermSwitch *) override; bool visitSwitch(Visit, TIntermSwitch *) override;
bool visitCase(Visit, TIntermCase *node) override; bool visitCase(Visit, TIntermCase *node) override;
......
...@@ -44,10 +44,12 @@ class TOutputTraverser : public TIntermTraverser ...@@ -44,10 +44,12 @@ class TOutputTraverser : public TIntermTraverser
void visitConstantUnion(TIntermConstantUnion *) override; void visitConstantUnion(TIntermConstantUnion *) override;
bool visitBinary(Visit visit, TIntermBinary *) override; bool visitBinary(Visit visit, TIntermBinary *) override;
bool visitUnary(Visit visit, TIntermUnary *) override; bool visitUnary(Visit visit, TIntermUnary *) override;
bool visitTernary(Visit visit, TIntermTernary *node) override;
bool visitSelection(Visit visit, TIntermSelection *) override; bool visitSelection(Visit visit, TIntermSelection *) override;
bool visitAggregate(Visit visit, TIntermAggregate *) override; bool visitAggregate(Visit visit, TIntermAggregate *) override;
bool visitLoop(Visit visit, TIntermLoop *) override; bool visitLoop(Visit visit, TIntermLoop *) override;
bool visitBranch(Visit visit, TIntermBranch *) override; bool visitBranch(Visit visit, TIntermBranch *) override;
// TODO: Add missing visit functions
}; };
// //
...@@ -457,13 +459,13 @@ bool TOutputTraverser::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -457,13 +459,13 @@ bool TOutputTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
return true; return true;
} }
bool TOutputTraverser::visitSelection(Visit visit, TIntermSelection *node) bool TOutputTraverser::visitTernary(Visit visit, TIntermTernary *node)
{ {
TInfoSinkBase &out = sink; TInfoSinkBase &out = sink;
OutputTreeText(out, node, mDepth); OutputTreeText(out, node, mDepth);
out << "Test condition and select"; out << "Ternary selection";
out << " (" << node->getCompleteString() << ")\n"; out << " (" << node->getCompleteString() << ")\n";
++mDepth; ++mDepth;
...@@ -473,6 +475,38 @@ bool TOutputTraverser::visitSelection(Visit visit, TIntermSelection *node) ...@@ -473,6 +475,38 @@ bool TOutputTraverser::visitSelection(Visit visit, TIntermSelection *node)
node->getCondition()->traverse(this); node->getCondition()->traverse(this);
OutputTreeText(sink, node, mDepth); OutputTreeText(sink, node, mDepth);
if (node->getTrueExpression())
{
out << "true case\n";
node->getTrueExpression()->traverse(this);
}
if (node->getFalseExpression())
{
OutputTreeText(sink, node, mDepth);
out << "false case\n";
node->getFalseExpression()->traverse(this);
}
--mDepth;
return false;
}
bool TOutputTraverser::visitSelection(Visit visit, TIntermSelection *node)
{
TInfoSinkBase &out = sink;
OutputTreeText(out, node, mDepth);
out << "If test\n";
++mDepth;
OutputTreeText(sink, node, mDepth);
out << "Condition\n";
node->getCondition()->traverse(this);
OutputTreeText(sink, node, mDepth);
if (node->getTrueBlock()) if (node->getTrueBlock())
{ {
out << "true case\n"; out << "true case\n";
......
...@@ -207,24 +207,24 @@ TEST_F(IntermNodeTest, DeepCopyAggregateNode) ...@@ -207,24 +207,24 @@ TEST_F(IntermNodeTest, DeepCopyAggregateNode)
} }
} }
// Check that the deep copy of a selection node is an actual copy with the same attributes as the // Check that the deep copy of a ternary node is an actual copy with the same attributes as the
// original. Child nodes also need to be copies with the same attributes as the original children. // original. Child nodes also need to be copies with the same attributes as the original children.
TEST_F(IntermNodeTest, DeepCopySelectionNode) TEST_F(IntermNodeTest, DeepCopyTernaryNode)
{ {
TType type(EbtFloat, EbpHigh); TType type(EbtFloat, EbpHigh);
TIntermSelection *original = new TIntermSelection( TIntermTernary *original = new TIntermTernary(createTestSymbol(TType(EbtBool, EbpUndefined)),
createTestSymbol(TType(EbtBool, EbpUndefined)), createTestSymbol(), createTestSymbol()); createTestSymbol(), createTestSymbol());
original->setLine(getTestSourceLoc()); original->setLine(getTestSourceLoc());
TIntermTyped *copyTyped = original->deepCopy(); TIntermTyped *copyTyped = original->deepCopy();
TIntermSelection *copy = copyTyped->getAsSelectionNode(); TIntermTernary *copy = copyTyped->getAsTernaryNode();
ASSERT_NE(nullptr, copy); ASSERT_NE(nullptr, copy);
ASSERT_NE(original, copy); ASSERT_NE(original, copy);
checkTestSourceLoc(copy->getLine()); checkTestSourceLoc(copy->getLine());
checkTypeEqualWithQualifiers(original->getType(), copy->getType()); checkTypeEqualWithQualifiers(original->getType(), copy->getType());
checkSymbolCopy(original->getCondition(), copy->getCondition()); checkSymbolCopy(original->getCondition(), copy->getCondition());
checkSymbolCopy(original->getTrueBlock(), copy->getTrueBlock()); checkSymbolCopy(original->getTrueExpression(), copy->getTrueExpression());
checkSymbolCopy(original->getFalseBlock(), copy->getFalseBlock()); checkSymbolCopy(original->getFalseExpression(), copy->getFalseExpression());
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment