Commit 06235df9 by Olli Etuaho Committed by Commit Bot

Make HLSL shaders use only one main function

Instead of having separate main() and gl_main() functions in HLSL shaders, add initializing outputs and inputs directly to the main function that's in the AST. This works around some HLSL bugs and should not introduce name conflicts inside main() since all the user-defined variables are prefixed. BUG=angleproject:2325 TEST=angle_end2end_tests Change-Id: I5b000c96aac8f321cefe50b6a893008498eac0d5 Reviewed-on: https://chromium-review.googlesource.com/1146647Reviewed-by: 's avatarGeoff Lang <geofflang@chromium.org> Reviewed-by: 's avatarCorentin Wallez <cwallez@chromium.org> Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
parent 522095f7
...@@ -182,6 +182,7 @@ OutputHLSL::OutputHLSL(sh::GLenum shaderType, ...@@ -182,6 +182,7 @@ OutputHLSL::OutputHLSL(sh::GLenum shaderType,
int numRenderTargets, int numRenderTargets,
const std::vector<Uniform> &uniforms, const std::vector<Uniform> &uniforms,
ShCompileOptions compileOptions, ShCompileOptions compileOptions,
sh::WorkGroupSize workGroupSize,
TSymbolTable *symbolTable, TSymbolTable *symbolTable,
PerformanceDiagnostics *perfDiagnostics) PerformanceDiagnostics *perfDiagnostics)
: TIntermTraverser(true, true, true, symbolTable), : TIntermTraverser(true, true, true, symbolTable),
...@@ -191,12 +192,13 @@ OutputHLSL::OutputHLSL(sh::GLenum shaderType, ...@@ -191,12 +192,13 @@ OutputHLSL::OutputHLSL(sh::GLenum shaderType,
mSourcePath(sourcePath), mSourcePath(sourcePath),
mOutputType(outputType), mOutputType(outputType),
mCompileOptions(compileOptions), mCompileOptions(compileOptions),
mInsideFunction(false),
mInsideMain(false),
mNumRenderTargets(numRenderTargets), mNumRenderTargets(numRenderTargets),
mCurrentFunctionMetadata(nullptr), mCurrentFunctionMetadata(nullptr),
mWorkGroupSize(workGroupSize),
mPerfDiagnostics(perfDiagnostics) mPerfDiagnostics(perfDiagnostics)
{ {
mInsideFunction = false;
mUsesFragColor = false; mUsesFragColor = false;
mUsesFragData = false; mUsesFragData = false;
mUsesDepthRange = false; mUsesDepthRange = false;
...@@ -1743,10 +1745,16 @@ bool OutputHLSL::visitBlock(Visit visit, TIntermBlock *node) ...@@ -1743,10 +1745,16 @@ bool OutputHLSL::visitBlock(Visit visit, TIntermBlock *node)
{ {
TInfoSinkBase &out = getInfoSink(); TInfoSinkBase &out = getInfoSink();
bool isMainBlock = mInsideMain && getParentNode()->getAsFunctionDefinition();
if (mInsideFunction) if (mInsideFunction)
{ {
outputLineDirective(out, node->getLine().first_line); outputLineDirective(out, node->getLine().first_line);
out << "{\n"; out << "{\n";
if (isMainBlock)
{
out << "@@ MAIN PROLOGUE @@\n";
}
} }
for (TIntermNode *statement : *node->getSequence()) for (TIntermNode *statement : *node->getSequence())
...@@ -1781,6 +1789,19 @@ bool OutputHLSL::visitBlock(Visit visit, TIntermBlock *node) ...@@ -1781,6 +1789,19 @@ bool OutputHLSL::visitBlock(Visit visit, TIntermBlock *node)
if (mInsideFunction) if (mInsideFunction)
{ {
outputLineDirective(out, node->getLine().last_line); outputLineDirective(out, node->getLine().last_line);
if (isMainBlock && shaderNeedsGenerateOutput())
{
// We could have an empty main, a main function without a branch at the end, or a main
// function with a discard statement at the end. In these cases we need to add a return
// statement.
bool needReturnStatement =
node->getSequence()->empty() || !node->getSequence()->back()->getAsBranchNode() ||
node->getSequence()->back()->getAsBranchNode()->getFlowOp() != EOpReturn;
if (needReturnStatement)
{
out << "return " << generateOutputCall() << ";\n";
}
}
out << "}\n"; out << "}\n";
} }
...@@ -1797,40 +1818,64 @@ bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition ...@@ -1797,40 +1818,64 @@ bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition
ASSERT(index != CallDAG::InvalidIndex); ASSERT(index != CallDAG::InvalidIndex);
mCurrentFunctionMetadata = &mASTMetadataList[index]; mCurrentFunctionMetadata = &mASTMetadataList[index];
out << TypeString(node->getFunctionPrototype()->getType()) << " ";
const TFunction *func = node->getFunction(); const TFunction *func = node->getFunction();
if (func->isMain()) if (func->isMain())
{ {
out << "gl_main("; // The stub strings below are replaced when shader is dynamically defined by its layout:
switch (mShaderType)
{
case GL_VERTEX_SHADER:
out << "@@ VERTEX ATTRIBUTES @@\n\n"
<< "@@ VERTEX OUTPUT @@\n\n"
<< "VS_OUTPUT main(VS_INPUT input)";
break;
case GL_FRAGMENT_SHADER:
out << "@@ PIXEL OUTPUT @@\n\n"
<< "PS_OUTPUT main(@@ PIXEL MAIN PARAMETERS @@)";
break;
case GL_COMPUTE_SHADER:
out << "[numthreads(" << mWorkGroupSize[0] << ", " << mWorkGroupSize[1] << ", "
<< mWorkGroupSize[2] << ")]\n";
out << "void main(CS_INPUT input)";
break;
default:
UNREACHABLE();
break;
}
} }
else else
{ {
out << TypeString(node->getFunctionPrototype()->getType()) << " ";
out << DecorateFunctionIfNeeded(func) << DisambiguateFunctionName(func) out << DecorateFunctionIfNeeded(func) << DisambiguateFunctionName(func)
<< (mOutputLod0Function ? "Lod0(" : "("); << (mOutputLod0Function ? "Lod0(" : "(");
}
size_t paramCount = func->getParamCount(); size_t paramCount = func->getParamCount();
for (unsigned int i = 0; i < paramCount; i++) for (unsigned int i = 0; i < paramCount; i++)
{ {
const TVariable *param = func->getParam(i); const TVariable *param = func->getParam(i);
ensureStructDefined(param->getType()); ensureStructDefined(param->getType());
writeParameter(param, out); writeParameter(param, out);
if (i < paramCount - 1) if (i < paramCount - 1)
{ {
out << ", "; out << ", ";
}
} }
}
out << ")\n"; out << ")\n";
}
mInsideFunction = true; mInsideFunction = true;
if (func->isMain())
{
mInsideMain = true;
}
// The function body node will output braces. // The function body node will output braces.
node->getBody()->traverse(this); node->getBody()->traverse(this);
mInsideFunction = false; mInsideFunction = false;
mInsideMain = false;
mCurrentFunctionMetadata = nullptr; mCurrentFunctionMetadata = nullptr;
...@@ -2455,11 +2500,19 @@ bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node) ...@@ -2455,11 +2500,19 @@ bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node)
case EOpReturn: case EOpReturn:
if (node->getExpression()) if (node->getExpression())
{ {
ASSERT(!mInsideMain);
out << "return "; out << "return ";
} }
else else
{ {
out << "return"; if (mInsideMain && shaderNeedsGenerateOutput())
{
out << "return " << generateOutputCall();
}
else
{
out << "return";
}
} }
break; break;
default: default:
...@@ -3154,4 +3207,21 @@ void OutputHLSL::ensureStructDefined(const TType &type) ...@@ -3154,4 +3207,21 @@ void OutputHLSL::ensureStructDefined(const TType &type)
} }
} }
bool OutputHLSL::shaderNeedsGenerateOutput() const
{
return mShaderType == GL_VERTEX_SHADER || mShaderType == GL_FRAGMENT_SHADER;
}
const char *OutputHLSL::generateOutputCall() const
{
if (mShaderType == GL_VERTEX_SHADER)
{
return "generateOutput(input)";
}
else
{
return "generateOutput()";
}
}
} // namespace sh } // namespace sh
...@@ -53,6 +53,7 @@ class OutputHLSL : public TIntermTraverser ...@@ -53,6 +53,7 @@ class OutputHLSL : public TIntermTraverser
int numRenderTargets, int numRenderTargets,
const std::vector<Uniform> &uniforms, const std::vector<Uniform> &uniforms,
ShCompileOptions compileOptions, ShCompileOptions compileOptions,
sh::WorkGroupSize workGroupSize,
TSymbolTable *symbolTable, TSymbolTable *symbolTable,
PerformanceDiagnostics *perfDiagnostics); PerformanceDiagnostics *perfDiagnostics);
...@@ -146,6 +147,9 @@ class OutputHLSL : public TIntermTraverser ...@@ -146,6 +147,9 @@ class OutputHLSL : public TIntermTraverser
// Ensures if the type is a struct, the struct is defined // Ensures if the type is a struct, the struct is defined
void ensureStructDefined(const TType &type); void ensureStructDefined(const TType &type);
bool shaderNeedsGenerateOutput() const;
const char *generateOutputCall() const;
sh::GLenum mShaderType; sh::GLenum mShaderType;
int mShaderVersion; int mShaderVersion;
const TExtensionBehavior &mExtensionBehavior; const TExtensionBehavior &mExtensionBehavior;
...@@ -154,6 +158,7 @@ class OutputHLSL : public TIntermTraverser ...@@ -154,6 +158,7 @@ class OutputHLSL : public TIntermTraverser
ShCompileOptions mCompileOptions; ShCompileOptions mCompileOptions;
bool mInsideFunction; bool mInsideFunction;
bool mInsideMain;
// Output streams // Output streams
TInfoSinkBase mHeader; TInfoSinkBase mHeader;
...@@ -250,6 +255,8 @@ class OutputHLSL : public TIntermTraverser ...@@ -250,6 +255,8 @@ class OutputHLSL : public TIntermTraverser
// arrays can't be return values in HLSL. // arrays can't be return values in HLSL.
std::vector<ArrayHelperFunction> mArrayConstructIntoFunctions; std::vector<ArrayHelperFunction> mArrayConstructIntoFunctions;
sh::WorkGroupSize mWorkGroupSize;
PerformanceDiagnostics *mPerfDiagnostics; PerformanceDiagnostics *mPerfDiagnostics;
private: private:
......
...@@ -128,7 +128,8 @@ void TranslatorHLSL::translate(TIntermBlock *root, ...@@ -128,7 +128,8 @@ void TranslatorHLSL::translate(TIntermBlock *root,
sh::OutputHLSL outputHLSL(getShaderType(), getShaderVersion(), getExtensionBehavior(), sh::OutputHLSL outputHLSL(getShaderType(), getShaderVersion(), getExtensionBehavior(),
getSourcePath(), getOutputType(), numRenderTargets, getUniforms(), getSourcePath(), getOutputType(), numRenderTargets, getUniforms(),
compileOptions, &getSymbolTable(), perfDiagnostics); compileOptions, getComputeShaderLocalSize(), &getSymbolTable(),
perfDiagnostics);
outputHLSL.output(root, getInfoSink().obj); outputHLSL.output(root, getInfoSink().obj);
......
...@@ -4615,6 +4615,40 @@ TEST_P(GLSLTest_ES3, AssignAssignmentToSwizzled) ...@@ -4615,6 +4615,40 @@ TEST_P(GLSLTest_ES3, AssignAssignmentToSwizzled)
EXPECT_PIXEL_COLOR_EQ(0, 0, GLColor::white); EXPECT_PIXEL_COLOR_EQ(0, 0, GLColor::white);
} }
// Test a fragment shader that returns inside if (that being the only branch that actually gets
// executed). Regression test for http://anglebug.com/2325
TEST_P(GLSLTest, IfElseIfAndReturn)
{
const std::string &vertexShader =
R"(attribute vec4 a_position;
varying vec2 vPos;
void main()
{
gl_Position = a_position;
vPos = a_position.xy;
})";
const std::string &fragmentShader =
R"(precision mediump float;
varying vec2 vPos;
void main()
{
if (vPos.x < 1.0) // This colors the whole canvas green
{
gl_FragColor = vec4(0, 1, 0, 1);
return;
}
else if (vPos.x < 1.1) // This should have no effect
{
gl_FragColor = vec4(1, 0, 0, 1);
}
})";
ANGLE_GL_PROGRAM(program, vertexShader, fragmentShader);
drawQuad(program.get(), "a_position", 0.5f);
EXPECT_PIXEL_COLOR_EQ(0, 0, GLColor::green);
}
// Use this to select which configurations (e.g. which renderer, which GLES major version) these // Use this to select which configurations (e.g. which renderer, which GLES major version) these
// tests should be run against. // tests should be run against.
ANGLE_INSTANTIATE_TEST(GLSLTest, ANGLE_INSTANTIATE_TEST(GLSLTest,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment