Commit 477b243b by Corentin Wallez

Change the FLATTEN heuristic to "ifs with a loop with a gradient"

This heuristic makes more sense than the previous "ifs with a discontinuous loop" as the reason we need to flatten is that we need gradients to be in branchless code. Change the UnrollFlatten test accordingly. Tested with: - the WebGL CTS - dev.miaumiau.cat/rayTracer "Skull Demo" - THe turbulenz engine GPU particle demo - Lots of ShaderToy Samples (inc. Volcanic, Metropolis and Hierarchical Voronoi) - Google Maps Earth mode - Lots of Chrome experiments - madebyevan.com/webgl-water BUG=524297 Change-Id: Iaa727036fffcfde3952716a1ef33b6ee0546b69d Reviewed-on: https://chromium-review.googlesource.com/296442Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org> Tested-by: 's avatarCorentin Wallez <cwallez@chromium.org>
parent 12d59314
...@@ -139,13 +139,16 @@ class PullGradient : public TIntermTraverser ...@@ -139,13 +139,16 @@ class PullGradient : public TIntermTraverser
std::vector<TIntermNode*> mParents; std::vector<TIntermNode*> mParents;
}; };
// Traverses the AST of a function definition, assuming it has already been used to // Traverses the AST of a function definition to compute the the discontinuous loops
// traverse the callees of that function; computes the discontinuous loops and the if // and the if statements containing gradient loops. It assumes that the gradient loops
// statements that contain a discontinuous loop in their call graph. // (loops that contain a gradient) have already been computed and that it has already
class PullComputeDiscontinuousLoops : public TIntermTraverser // traversed the current function's callees.
class PullComputeDiscontinuousAndGradientLoops : public TIntermTraverser
{ {
public: public:
PullComputeDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag) PullComputeDiscontinuousAndGradientLoops(MetadataList *metadataList,
size_t index,
const CallDAG &dag)
: TIntermTraverser(true, false, true), : TIntermTraverser(true, false, true),
mMetadataList(metadataList), mMetadataList(metadataList),
mMetadata(&(*metadataList)[index]), mMetadata(&(*metadataList)[index]),
...@@ -161,15 +164,15 @@ class PullComputeDiscontinuousLoops : public TIntermTraverser ...@@ -161,15 +164,15 @@ class PullComputeDiscontinuousLoops : public TIntermTraverser
ASSERT(mIfs.empty()); ASSERT(mIfs.empty());
} }
// Called when a discontinuous loop or a call to a function with a discontinuous loop // Called when traversing a gradient loop or a call to a function with a
// in its call graph is found. // gradient loop in its call graph.
void onDiscontinuousLoop() void onGradientLoop()
{ {
mMetadata->mHasDiscontinuousLoopInCallGraph = true; mMetadata->mHasGradientLoopInCallGraph = true;
// Mark the latest if as using a discontinuous loop. // Mark the latest if as using a discontinuous loop.
if (!mIfs.empty()) if (!mIfs.empty())
{ {
mMetadata->mIfsContainingDiscontinuousLoop.insert(mIfs.back()); mMetadata->mIfsContainingGradientLoop.insert(mIfs.back());
} }
} }
...@@ -178,6 +181,11 @@ class PullComputeDiscontinuousLoops : public TIntermTraverser ...@@ -178,6 +181,11 @@ class PullComputeDiscontinuousLoops : public TIntermTraverser
if (visit == PreVisit) if (visit == PreVisit)
{ {
mLoopsAndSwitches.push_back(loop); mLoopsAndSwitches.push_back(loop);
if (mMetadata->hasGradientInCallGraph(loop))
{
onGradientLoop();
}
} }
else if (visit == PostVisit) else if (visit == PostVisit)
{ {
...@@ -199,9 +207,9 @@ class PullComputeDiscontinuousLoops : public TIntermTraverser ...@@ -199,9 +207,9 @@ class PullComputeDiscontinuousLoops : public TIntermTraverser
ASSERT(mIfs.back() == node); ASSERT(mIfs.back() == node);
mIfs.pop_back(); mIfs.pop_back();
// An if using a discontinuous loop means its parents ifs are also discontinuous. // An if using a discontinuous loop means its parents ifs are also discontinuous.
if (mMetadata->mIfsContainingDiscontinuousLoop.count(node) > 0 && !mIfs.empty()) if (mMetadata->mIfsContainingGradientLoop.count(node) > 0 && !mIfs.empty())
{ {
mMetadata->mIfsContainingDiscontinuousLoop.insert(mIfs.back()); mMetadata->mIfsContainingGradientLoop.insert(mIfs.back());
} }
} }
...@@ -221,7 +229,6 @@ class PullComputeDiscontinuousLoops : public TIntermTraverser ...@@ -221,7 +229,6 @@ class PullComputeDiscontinuousLoops : public TIntermTraverser
if (loop != nullptr) if (loop != nullptr)
{ {
mMetadata->mDiscontinuousLoops.insert(loop); mMetadata->mDiscontinuousLoops.insert(loop);
onDiscontinuousLoop();
} }
} }
break; break;
...@@ -237,7 +244,6 @@ class PullComputeDiscontinuousLoops : public TIntermTraverser ...@@ -237,7 +244,6 @@ class PullComputeDiscontinuousLoops : public TIntermTraverser
} }
ASSERT(loop != nullptr); ASSERT(loop != nullptr);
mMetadata->mDiscontinuousLoops.insert(loop); mMetadata->mDiscontinuousLoops.insert(loop);
onDiscontinuousLoop();
} }
break; break;
case EOpKill: case EOpKill:
...@@ -253,7 +259,6 @@ class PullComputeDiscontinuousLoops : public TIntermTraverser ...@@ -253,7 +259,6 @@ class PullComputeDiscontinuousLoops : public TIntermTraverser
mMetadata->mDiscontinuousLoops.insert(loop); mMetadata->mDiscontinuousLoops.insert(loop);
} }
} }
onDiscontinuousLoop();
} }
break; break;
default: default:
...@@ -274,9 +279,9 @@ class PullComputeDiscontinuousLoops : public TIntermTraverser ...@@ -274,9 +279,9 @@ class PullComputeDiscontinuousLoops : public TIntermTraverser
ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex); ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
UNUSED_ASSERTION_VARIABLE(mIndex); UNUSED_ASSERTION_VARIABLE(mIndex);
if ((*mMetadataList)[calleeIndex].mHasDiscontinuousLoopInCallGraph) if ((*mMetadataList)[calleeIndex].mHasGradientLoopInCallGraph)
{ {
onDiscontinuousLoop(); onGradientLoop();
} }
} }
} }
...@@ -375,19 +380,14 @@ class PushDiscontinuousLoops : public TIntermTraverser ...@@ -375,19 +380,14 @@ class PushDiscontinuousLoops : public TIntermTraverser
} }
bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermSelection *node)
{
return mControlFlowsContainingGradient.count(node) > 0;
}
bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermLoop *node) bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermLoop *node)
{ {
return mControlFlowsContainingGradient.count(node) > 0; return mControlFlowsContainingGradient.count(node) > 0;
} }
bool ASTMetadataHLSL::hasDiscontinuousLoop(TIntermSelection *node) bool ASTMetadataHLSL::hasGradientLoop(TIntermSelection *node)
{ {
return mIfsContainingDiscontinuousLoop.count(node) > 0; return mIfsContainingGradientLoop.count(node) > 0;
} }
MetadataList CreateASTMetadataHLSL(TIntermNode *root, const CallDAG &callDag) MetadataList CreateASTMetadataHLSL(TIntermNode *root, const CallDAG &callDag)
...@@ -424,10 +424,10 @@ MetadataList CreateASTMetadataHLSL(TIntermNode *root, const CallDAG &callDag) ...@@ -424,10 +424,10 @@ MetadataList CreateASTMetadataHLSL(TIntermNode *root, const CallDAG &callDag)
// of callgraph analysis as for the gradient. // of callgraph analysis as for the gradient.
// First compute which loops are discontinuous (no specific order) and pull // First compute which loops are discontinuous (no specific order) and pull
// the ifs and functions using a discontinuous loop. // the ifs and functions using a gradient loop.
for (size_t i = 0; i < callDag.size(); i++) for (size_t i = 0; i < callDag.size(); i++)
{ {
PullComputeDiscontinuousLoops pull(&metadataList, i, callDag); PullComputeDiscontinuousAndGradientLoops pull(&metadataList, i, callDag);
pull.traverse(callDag.getRecordFromIndex(i).node); pull.traverse(callDag.getRecordFromIndex(i).node);
} }
......
...@@ -22,16 +22,15 @@ struct ASTMetadataHLSL ...@@ -22,16 +22,15 @@ struct ASTMetadataHLSL
ASTMetadataHLSL() ASTMetadataHLSL()
: mUsesGradient(false), : mUsesGradient(false),
mCalledInDiscontinuousLoop(false), mCalledInDiscontinuousLoop(false),
mHasDiscontinuousLoopInCallGraph(false), mHasGradientLoopInCallGraph(false),
mNeedsLod0(false) mNeedsLod0(false)
{ {
} }
// Here "something uses a gradient" means here that it either contains a // Here "something uses a gradient" means here that it either contains a
// gradient operation, or a call to a function that uses a gradient. // gradient operation, or a call to a function that uses a gradient.
bool hasGradientInCallGraph(TIntermSelection *node);
bool hasGradientInCallGraph(TIntermLoop *node); bool hasGradientInCallGraph(TIntermLoop *node);
bool hasDiscontinuousLoop(TIntermSelection *node); bool hasGradientLoop(TIntermSelection *node);
// Does the function use a gradient. // Does the function use a gradient.
bool mUsesGradient; bool mUsesGradient;
...@@ -43,9 +42,9 @@ struct ASTMetadataHLSL ...@@ -43,9 +42,9 @@ struct ASTMetadataHLSL
// Remember information about the discontinuous loops and which functions // Remember information about the discontinuous loops and which functions
// are called in such loops. // are called in such loops.
bool mCalledInDiscontinuousLoop; bool mCalledInDiscontinuousLoop;
bool mHasDiscontinuousLoopInCallGraph; bool mHasGradientLoopInCallGraph;
std::set<TIntermLoop*> mDiscontinuousLoops; std::set<TIntermLoop*> mDiscontinuousLoops;
std::set<TIntermSelection*> mIfsContainingDiscontinuousLoop; std::set<TIntermSelection *> mIfsContainingGradientLoop;
// Will we need to generate a Lod0 version of the function. // Will we need to generate a Lod0 version of the function.
bool mNeedsLod0; bool mNeedsLod0;
......
...@@ -2374,9 +2374,7 @@ bool OutputHLSL::visitSelection(Visit visit, TIntermSelection *node) ...@@ -2374,9 +2374,7 @@ bool OutputHLSL::visitSelection(Visit visit, TIntermSelection *node)
} }
// D3D errors when there is a gradient operation in a loop in an unflattened if. // D3D errors when there is a gradient operation in a loop in an unflattened if.
if (mShaderType == GL_FRAGMENT_SHADER && if (mShaderType == GL_FRAGMENT_SHADER && mCurrentFunctionMetadata->hasGradientLoop(node))
mCurrentFunctionMetadata->hasDiscontinuousLoop(node) &&
mCurrentFunctionMetadata->hasGradientInCallGraph(node))
{ {
out << "FLATTEN "; out << "FLATTEN ";
} }
......
...@@ -105,7 +105,7 @@ TEST_F(UnrollFlattenTest, NoGradient) ...@@ -105,7 +105,7 @@ TEST_F(UnrollFlattenTest, NoGradient)
// 2 - no FLATTEN because does not contain discont loop // 2 - no FLATTEN because does not contain discont loop
// 3 - shouldn't get a Lod0 version generated // 3 - shouldn't get a Lod0 version generated
// 4 - no LOOP because discont, and also no gradient // 4 - no LOOP because discont, and also no gradient
// 5 - no FLATTEN because does not contain discont loop // 5 - no FLATTEN because does not contain loop with a gradient
// 6 - call non-Lod0 version // 6 - call non-Lod0 version
// 7 - no FLATTEN // 7 - no FLATTEN
const char *expectations[] = const char *expectations[] =
...@@ -146,16 +146,16 @@ TEST_F(UnrollFlattenTest, GradientNotInDiscont) ...@@ -146,16 +146,16 @@ TEST_F(UnrollFlattenTest, GradientNotInDiscont)
// 2 - no Lod0 version generated // 2 - no Lod0 version generated
// 3 - shouldn't get a Lod0 version generated (not in discont loop) // 3 - shouldn't get a Lod0 version generated (not in discont loop)
// 4 - should have LOOP because it contains a gradient operation (even if Lod0) // 4 - should have LOOP because it contains a gradient operation (even if Lod0)
// 5 - no FLATTEN because doesn't contain discont loop // 5 - no FLATTEN because doesn't contain loop with a gradient
// 6 - call non-Lod0 version // 6 - call non-Lod0 version
// 7 - call non-Lod0 version // 7 - call non-Lod0 version
// 8 - no FLATTEN // 8 - FLATTEN because it contains a loop with a gradient
compile(shaderString); compile(shaderString);
const char *expectations[] = const char *expectations[] =
{ {
"fun(", "texture2D(", "fun(", "texture2D(",
"fun2(", "LOOP", "for", "if", "fun(", "texture2D(", "fun2(", "LOOP", "for", "if", "fun(", "texture2D(",
"main(", "if", "fun2(" "main(", "FLATTEN", "if", "fun2("
}; };
expect(expectations, ArraySize(expectations)); expect(expectations, ArraySize(expectations));
} }
...@@ -188,10 +188,10 @@ TEST_F(UnrollFlattenTest, GradientInDiscont) ...@@ -188,10 +188,10 @@ TEST_F(UnrollFlattenTest, GradientInDiscont)
// 2 - will get the Lod0 if in funLod0 // 2 - will get the Lod0 if in funLod0
// 3 - shouldn't get a Lod0 version generated (not in discont loop) // 3 - shouldn't get a Lod0 version generated (not in discont loop)
// 4 - should have LOOP because it contains a gradient operation (even if Lod0) // 4 - should have LOOP because it contains a gradient operation (even if Lod0)
// 5 - no FLATTEN because doesn't contain discont loop // 5 - no FLATTEN because doesn't contain a loop with a gradient
// 6 - call Lod0 version // 6 - call Lod0 version
// 7 - call Lod0 version // 7 - call Lod0 version
// 8 - should have a FLATTEN because has a discont loop and gradient // 8 - FLATTEN because it contains a loop with a gradient
compile(shaderString); compile(shaderString);
const char *expectations[] = const char *expectations[] =
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment