Commit 00f6fbbe by Olli Etuaho Committed by Commit Bot

Add IntermNodePatternMatcher helper class

This will enable sharing code between different AST traversers that apply transformations on similar node structures. This will make the code more maintainable. For now the helper class is used in UnfoldShortCircuitToIf and SeparateExpressionsReturningArrays. BUG=angleproject:1341 TEST=angle_end2end_tests, WebGL 2 conformance tests Change-Id: Ib1e0d5a84fd05bcca983b34f18d47c53e86dc227 Reviewed-on: https://chromium-review.googlesource.com/361693Reviewed-by: 's avatarCorentin Wallez <cwallez@chromium.org> Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
parent b21e20da
...@@ -168,6 +168,8 @@ ...@@ -168,6 +168,8 @@
'compiler/translator/blocklayoutHLSL.h', 'compiler/translator/blocklayoutHLSL.h',
'compiler/translator/BuiltInFunctionEmulatorHLSL.cpp', 'compiler/translator/BuiltInFunctionEmulatorHLSL.cpp',
'compiler/translator/BuiltInFunctionEmulatorHLSL.h', 'compiler/translator/BuiltInFunctionEmulatorHLSL.h',
'compiler/translator/IntermNodePatternMatcher.cpp',
'compiler/translator/IntermNodePatternMatcher.h',
'compiler/translator/OutputHLSL.cpp', 'compiler/translator/OutputHLSL.cpp',
'compiler/translator/OutputHLSL.h', 'compiler/translator/OutputHLSL.h',
'compiler/translator/RemoveDynamicIndexing.cpp', 'compiler/translator/RemoveDynamicIndexing.cpp',
......
//
// Copyright (c) 2016 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// IntermNodePatternMatcher is a helper class for matching node trees to given patterns.
// It can be used whenever the same checks for certain node structures are common to multiple AST
// traversers.
//
#include "compiler/translator/IntermNodePatternMatcher.h"
#include "compiler/translator/IntermNode.h"
namespace
{
bool IsNodeBlock(TIntermNode *node)
{
ASSERT(node != nullptr);
return (node->getAsAggregate() && node->getAsAggregate()->getOp() == EOpSequence);
}
} // anonymous namespace
IntermNodePatternMatcher::IntermNodePatternMatcher(const unsigned int mask) : mMask(mask)
{
}
bool IntermNodePatternMatcher::match(TIntermBinary *node, TIntermNode *parentNode)
{
if ((mMask & kExpressionReturningArray) != 0)
{
if (node->isArray() && node->getOp() == EOpAssign && parentNode != nullptr &&
!IsNodeBlock(parentNode))
{
return true;
}
}
if ((mMask & kUnfoldedShortCircuitExpression) != 0)
{
if (node->getRight()->hasSideEffects() &&
(node->getOp() == EOpLogicalOr || node->getOp() == EOpLogicalAnd))
{
return true;
}
}
return false;
}
bool IntermNodePatternMatcher::match(TIntermAggregate *node, TIntermNode *parentNode)
{
if ((mMask & kExpressionReturningArray) != 0)
{
if (parentNode != nullptr)
{
TIntermBinary *parentBinary = parentNode->getAsBinaryNode();
bool parentIsAssignment =
(parentBinary != nullptr &&
(parentBinary->getOp() == EOpAssign || parentBinary->getOp() == EOpInitialize));
if (node->getType().isArray() && !parentIsAssignment &&
(node->isConstructor() || node->getOp() == EOpFunctionCall) &&
!IsNodeBlock(parentNode))
{
return true;
}
}
}
return false;
}
bool IntermNodePatternMatcher::match(TIntermSelection *node)
{
if ((mMask & kUnfoldedShortCircuitExpression) != 0)
{
if (node->usesTernaryOperator())
{
return true;
}
}
return false;
}
//
// Copyright (c) 2016 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// IntermNodePatternMatcher is a helper class for matching node trees to given patterns.
// It can be used whenever the same checks for certain node structures are common to multiple AST
// traversers.
//
#ifndef COMPILER_TRANSLATOR_INTERMNODEPATTERNMATCHER_H_
#define COMPILER_TRANSLATOR_INTERMNODEPATTERNMATCHER_H_
class TIntermAggregate;
class TIntermBinary;
class TIntermNode;
class TIntermSelection;
class IntermNodePatternMatcher
{
public:
enum PatternType
{
kUnfoldedShortCircuitExpression = 0x0001,
// Matches expressions that return arrays with the exception of simple statements where a
// constructor or function call result is assigned.
kExpressionReturningArray = 0x0002
};
IntermNodePatternMatcher(const unsigned int mask);
bool match(TIntermBinary *node, TIntermNode *parentNode);
bool match(TIntermAggregate *node, TIntermNode *parentNode);
bool match(TIntermSelection *node);
private:
const unsigned int mMask;
};
#endif
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "compiler/translator/SeparateExpressionsReturningArrays.h" #include "compiler/translator/SeparateExpressionsReturningArrays.h"
#include "compiler/translator/IntermNode.h" #include "compiler/translator/IntermNode.h"
#include "compiler/translator/IntermNodePatternMatcher.h"
namespace namespace
{ {
...@@ -32,11 +33,14 @@ class SeparateExpressionsTraverser : public TIntermTraverser ...@@ -32,11 +33,14 @@ class SeparateExpressionsTraverser : public TIntermTraverser
// Marked to true once an operation that needs to be hoisted out of the expression has been found. // Marked to true once an operation that needs to be hoisted out of the expression has been found.
// After that, no more AST updates are performed on that traversal. // After that, no more AST updates are performed on that traversal.
bool mFoundArrayExpression; bool mFoundArrayExpression;
IntermNodePatternMatcher mPatternToSeparateMatcher;
}; };
SeparateExpressionsTraverser::SeparateExpressionsTraverser() SeparateExpressionsTraverser::SeparateExpressionsTraverser()
: TIntermTraverser(true, false, false), : TIntermTraverser(true, false, false),
mFoundArrayExpression(false) mFoundArrayExpression(false),
mPatternToSeparateMatcher(IntermNodePatternMatcher::kExpressionReturningArray)
{ {
} }
...@@ -73,31 +77,27 @@ bool SeparateExpressionsTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -73,31 +77,27 @@ bool SeparateExpressionsTraverser::visitBinary(Visit visit, TIntermBinary *node)
if (mFoundArrayExpression) if (mFoundArrayExpression)
return false; return false;
// Early return if the expression is not an array or if we're not inside a complex expression. // Return if the expression is not an array or if we're not inside a complex expression.
if (!node->getType().isArray() || parentNodeIsBlock()) if (!mPatternToSeparateMatcher.match(node, getParentNode()))
return true; return true;
switch (node->getOp()) ASSERT(node->getOp() == EOpAssign);
{
case EOpAssign: mFoundArrayExpression = true;
{
mFoundArrayExpression = true; TIntermSequence insertions;
insertions.push_back(CopyAssignmentNode(node));
TIntermSequence insertions; // TODO(oetuaho): In some cases it would be more optimal to not add the temporary node, but just
insertions.push_back(CopyAssignmentNode(node)); // use the original target of the assignment. Care must be taken so that this doesn't happen
// TODO(oetuaho): In some cases it would be more optimal to not add the temporary node, but just use the // when the same array symbol is a target of assignment more than once in one expression.
// original target of the assignment. Care must be taken so that this doesn't happen when the same array insertions.push_back(createTempInitDeclaration(node->getLeft()));
// symbol is a target of assignment more than once in one expression. insertStatementsInParentBlock(insertions);
insertions.push_back(createTempInitDeclaration(node->getLeft()));
insertStatementsInParentBlock(insertions); NodeUpdateEntry replaceVariable(getParentNode(), node, createTempSymbol(node->getType()),
false);
NodeUpdateEntry replaceVariable(getParentNode(), node, createTempSymbol(node->getType()), false); mReplacements.push_back(replaceVariable);
mReplacements.push_back(replaceVariable);
} return false;
return false;
default:
return true;
}
} }
bool SeparateExpressionsTraverser::visitAggregate(Visit visit, TIntermAggregate *node) bool SeparateExpressionsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
...@@ -105,43 +105,22 @@ bool SeparateExpressionsTraverser::visitAggregate(Visit visit, TIntermAggregate ...@@ -105,43 +105,22 @@ bool SeparateExpressionsTraverser::visitAggregate(Visit visit, TIntermAggregate
if (mFoundArrayExpression) if (mFoundArrayExpression)
return false; // No need to traverse further return false; // No need to traverse further
if (getParentNode() != nullptr) if (!mPatternToSeparateMatcher.match(node, getParentNode()))
{ return true;
TIntermBinary *parentBinary = getParentNode()->getAsBinaryNode();
bool parentIsAssignment = (parentBinary != nullptr &&
(parentBinary->getOp() == EOpAssign || parentBinary->getOp() == EOpInitialize));
if (!node->getType().isArray() || parentNodeIsBlock() || parentIsAssignment)
return true;
if (node->isConstructor())
{
mFoundArrayExpression = true;
TIntermSequence insertions;
insertions.push_back(createTempInitDeclaration(CopyAggregateNode(node)));
insertStatementsInParentBlock(insertions);
NodeUpdateEntry replaceVariable(getParentNode(), node, createTempSymbol(node->getType()), false); ASSERT(node->isConstructor() || node->getOp() == EOpFunctionCall);
mReplacements.push_back(replaceVariable);
return false; mFoundArrayExpression = true;
}
else if (node->getOp() == EOpFunctionCall)
{
mFoundArrayExpression = true;
TIntermSequence insertions; TIntermSequence insertions;
insertions.push_back(createTempInitDeclaration(CopyAggregateNode(node))); insertions.push_back(createTempInitDeclaration(CopyAggregateNode(node)));
insertStatementsInParentBlock(insertions); insertStatementsInParentBlock(insertions);
NodeUpdateEntry replaceVariable(getParentNode(), node, createTempSymbol(node->getType()), false); NodeUpdateEntry replaceVariable(getParentNode(), node, createTempSymbol(node->getType()),
mReplacements.push_back(replaceVariable); false);
mReplacements.push_back(replaceVariable);
return false; return false;
}
}
return true;
} }
void SeparateExpressionsTraverser::nextIteration() void SeparateExpressionsTraverser::nextIteration()
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "compiler/translator/UnfoldShortCircuitToIf.h" #include "compiler/translator/UnfoldShortCircuitToIf.h"
#include "compiler/translator/IntermNode.h" #include "compiler/translator/IntermNode.h"
#include "compiler/translator/IntermNodePatternMatcher.h"
namespace namespace
{ {
...@@ -48,6 +49,8 @@ class UnfoldShortCircuitTraverser : public TIntermTraverser ...@@ -48,6 +49,8 @@ class UnfoldShortCircuitTraverser : public TIntermTraverser
bool mInLoopCondition; bool mInLoopCondition;
bool mInLoopExpression; bool mInLoopExpression;
IntermNodePatternMatcher mPatternToUnfoldMatcher;
}; };
UnfoldShortCircuitTraverser::UnfoldShortCircuitTraverser() UnfoldShortCircuitTraverser::UnfoldShortCircuitTraverser()
...@@ -56,7 +59,8 @@ UnfoldShortCircuitTraverser::UnfoldShortCircuitTraverser() ...@@ -56,7 +59,8 @@ UnfoldShortCircuitTraverser::UnfoldShortCircuitTraverser()
mParentLoop(nullptr), mParentLoop(nullptr),
mLoopParent(nullptr), mLoopParent(nullptr),
mInLoopCondition(false), mInLoopCondition(false),
mInLoopExpression(false) mInLoopExpression(false),
mPatternToUnfoldMatcher(IntermNodePatternMatcher::kUnfoldedShortCircuitExpression)
{ {
} }
...@@ -64,18 +68,23 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -64,18 +68,23 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node)
{ {
if (mFoundShortCircuit) if (mFoundShortCircuit)
return false; return false;
if (visit != PreVisit)
return true;
if (!mPatternToUnfoldMatcher.match(node, getParentNode()))
return true;
// If our right node doesn't have side effects, we know we don't need to unfold this // If our right node doesn't have side effects, we know we don't need to unfold this
// expression: there will be no short-circuiting side effects to avoid // expression: there will be no short-circuiting side effects to avoid
// (note: unfolding doesn't depend on the left node -- it will always be evaluated) // (note: unfolding doesn't depend on the left node -- it will always be evaluated)
if (!node->getRight()->hasSideEffects()) ASSERT(node->getRight()->hasSideEffects());
{
return true; mFoundShortCircuit = true;
}
switch (node->getOp()) switch (node->getOp())
{ {
case EOpLogicalOr: case EOpLogicalOr:
mFoundShortCircuit = true;
if (!copyLoopConditionOrExpression(getParentNode(), node)) if (!copyLoopConditionOrExpression(getParentNode(), node))
{ {
// "x || y" is equivalent to "x ? true : y", which unfolds to "bool s; if(x) s = true; // "x || y" is equivalent to "x ? true : y", which unfolds to "bool s; if(x) s = true;
...@@ -104,7 +113,6 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -104,7 +113,6 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node)
} }
return false; return false;
case EOpLogicalAnd: case EOpLogicalAnd:
mFoundShortCircuit = true;
if (!copyLoopConditionOrExpression(getParentNode(), node)) if (!copyLoopConditionOrExpression(getParentNode(), node))
{ {
// "x && y" is equivalent to "x ? y : false", which unfolds to "bool s; if(x) s = y; // "x && y" is equivalent to "x ? y : false", which unfolds to "bool s; if(x) s = y;
...@@ -130,7 +138,8 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -130,7 +138,8 @@ bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node)
} }
return false; return false;
default: default:
return true; UNREACHABLE();
return true;
} }
} }
...@@ -139,43 +148,46 @@ bool UnfoldShortCircuitTraverser::visitSelection(Visit visit, TIntermSelection * ...@@ -139,43 +148,46 @@ bool UnfoldShortCircuitTraverser::visitSelection(Visit visit, TIntermSelection *
if (mFoundShortCircuit) if (mFoundShortCircuit)
return false; return false;
if (visit != PreVisit)
return true;
if (!mPatternToUnfoldMatcher.match(node))
return true;
mFoundShortCircuit = true;
ASSERT(node->usesTernaryOperator());
// Unfold "b ? x : y" into "type s; if(b) s = x; else s = y;" // Unfold "b ? x : y" into "type s; if(b) s = x; else s = y;"
if (visit == PreVisit && node->usesTernaryOperator()) if (!copyLoopConditionOrExpression(getParentNode(), node))
{ {
mFoundShortCircuit = true; TIntermSequence insertions;
if (!copyLoopConditionOrExpression(getParentNode(), node))
{
TIntermSequence insertions;
TIntermSymbol *tempSymbol = createTempSymbol(node->getType()); TIntermSymbol *tempSymbol = createTempSymbol(node->getType());
TIntermAggregate *tempDeclaration = new TIntermAggregate(EOpDeclaration); TIntermAggregate *tempDeclaration = new TIntermAggregate(EOpDeclaration);
tempDeclaration->getSequence()->push_back(tempSymbol); tempDeclaration->getSequence()->push_back(tempSymbol);
insertions.push_back(tempDeclaration); insertions.push_back(tempDeclaration);
TIntermAggregate *trueBlock = new TIntermAggregate(EOpSequence); TIntermAggregate *trueBlock = new TIntermAggregate(EOpSequence);
TIntermBinary *trueAssignment = TIntermBinary *trueAssignment = createTempAssignment(node->getTrueBlock()->getAsTyped());
createTempAssignment(node->getTrueBlock()->getAsTyped()); trueBlock->getSequence()->push_back(trueAssignment);
trueBlock->getSequence()->push_back(trueAssignment);
TIntermAggregate *falseBlock = new TIntermAggregate(EOpSequence); TIntermAggregate *falseBlock = new TIntermAggregate(EOpSequence);
TIntermBinary *falseAssignment = TIntermBinary *falseAssignment = createTempAssignment(node->getFalseBlock()->getAsTyped());
createTempAssignment(node->getFalseBlock()->getAsTyped()); falseBlock->getSequence()->push_back(falseAssignment);
falseBlock->getSequence()->push_back(falseAssignment);
TIntermSelection *ifNode = TIntermSelection *ifNode =
new TIntermSelection(node->getCondition()->getAsTyped(), trueBlock, falseBlock); new TIntermSelection(node->getCondition()->getAsTyped(), trueBlock, falseBlock);
insertions.push_back(ifNode); insertions.push_back(ifNode);
insertStatementsInParentBlock(insertions); insertStatementsInParentBlock(insertions);
TIntermSymbol *ternaryResult = createTempSymbol(node->getType()); TIntermSymbol *ternaryResult = createTempSymbol(node->getType());
NodeUpdateEntry replaceVariable(getParentNode(), node, ternaryResult, false); NodeUpdateEntry replaceVariable(getParentNode(), node, ternaryResult, false);
mReplacements.push_back(replaceVariable); mReplacements.push_back(replaceVariable);
}
return false;
} }
return true; return false;
} }
bool UnfoldShortCircuitTraverser::visitAggregate(Visit visit, TIntermAggregate *node) bool UnfoldShortCircuitTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment