Compiler - split header, body and footer output

TRAC #11798 Signed-off-by: Shannon Woods Signed-off-by: Daniel Koch Author: Nicolas Capens git-svn-id: https://angleproject.googlecode.com/svn/trunk@126 736b8ea6-26fd-11df-bfd4-992fa37f6226
parent 73c2c2ed
...@@ -11,14 +11,25 @@ ...@@ -11,14 +11,25 @@
namespace sh namespace sh
{ {
OutputHLSL::OutputHLSL(TParseContext &context) : TIntermTraverser(true, true, true), context(context) OutputHLSL::OutputHLSL(TParseContext &context) : TIntermTraverser(true, true, true), mContext(context)
{ {
} }
void OutputHLSL::output()
{
mContext.treeRoot->traverse(this); // Output the body first to determine what has to go in the header and footer
header();
footer();
mContext.infoSink.obj << mHeader.c_str();
mContext.infoSink.obj << mBody.c_str();
mContext.infoSink.obj << mFooter.c_str();
}
void OutputHLSL::header() void OutputHLSL::header()
{ {
EShLanguage language = context.language; EShLanguage language = mContext.language;
TInfoSinkBase &out = context.infoSink.obj; TInfoSinkBase &out = mHeader;
if (language == EShLangFragment) if (language == EShLangFragment)
{ {
...@@ -26,7 +37,7 @@ void OutputHLSL::header() ...@@ -26,7 +37,7 @@ void OutputHLSL::header()
TString varyingInput; TString varyingInput;
TString varyingGlobals; TString varyingGlobals;
TSymbolTableLevel *symbols = context.symbolTable.getGlobalLevel(); TSymbolTableLevel *symbols = mContext.symbolTable.getGlobalLevel();
int semanticIndex = 0; int semanticIndex = 0;
for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++) for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++)
...@@ -119,7 +130,7 @@ void OutputHLSL::header() ...@@ -119,7 +130,7 @@ void OutputHLSL::header()
TString varyingOutput; TString varyingOutput;
TString varyingGlobals; TString varyingGlobals;
TSymbolTableLevel *symbols = context.symbolTable.getGlobalLevel(); TSymbolTableLevel *symbols = mContext.symbolTable.getGlobalLevel();
int semanticIndex = 0; int semanticIndex = 0;
for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++) for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++)
...@@ -416,9 +427,9 @@ void OutputHLSL::header() ...@@ -416,9 +427,9 @@ void OutputHLSL::header()
void OutputHLSL::footer() void OutputHLSL::footer()
{ {
EShLanguage language = context.language; EShLanguage language = mContext.language;
TInfoSinkBase &out = context.infoSink.obj; TInfoSinkBase &out = mFooter;
TSymbolTableLevel *symbols = context.symbolTable.getGlobalLevel(); TSymbolTableLevel *symbols = mContext.symbolTable.getGlobalLevel();
if (language == EShLangFragment) if (language == EShLangFragment)
{ {
...@@ -489,7 +500,7 @@ void OutputHLSL::footer() ...@@ -489,7 +500,7 @@ void OutputHLSL::footer()
" output.gl_PointSize = gl_PointSize;\n" " output.gl_PointSize = gl_PointSize;\n"
" output.gl_FragCoord = gl_Position;\n"; " output.gl_FragCoord = gl_Position;\n";
TSymbolTableLevel *symbols = context.symbolTable.getGlobalLevel(); TSymbolTableLevel *symbols = mContext.symbolTable.getGlobalLevel();
for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++) for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++)
{ {
...@@ -516,7 +527,7 @@ void OutputHLSL::footer() ...@@ -516,7 +527,7 @@ void OutputHLSL::footer()
void OutputHLSL::visitSymbol(TIntermSymbol *node) void OutputHLSL::visitSymbol(TIntermSymbol *node)
{ {
TInfoSinkBase &out = context.infoSink.obj; TInfoSinkBase &out = mBody;
TString name = node->getSymbol(); TString name = node->getSymbol();
...@@ -536,7 +547,7 @@ void OutputHLSL::visitSymbol(TIntermSymbol *node) ...@@ -536,7 +547,7 @@ void OutputHLSL::visitSymbol(TIntermSymbol *node)
bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node) bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
{ {
TInfoSinkBase &out = context.infoSink.obj; TInfoSinkBase &out = mBody;
switch (node->getOp()) switch (node->getOp())
{ {
...@@ -711,7 +722,7 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node) ...@@ -711,7 +722,7 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node) bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
{ {
TInfoSinkBase &out = context.infoSink.obj; TInfoSinkBase &out = mBody;
switch (node->getOp()) switch (node->getOp())
{ {
...@@ -789,8 +800,8 @@ bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node) ...@@ -789,8 +800,8 @@ bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node) bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
{ {
EShLanguage language = context.language; EShLanguage language = mContext.language;
TInfoSinkBase &out = context.infoSink.obj; TInfoSinkBase &out = mBody;
if (node->getOp() == EOpNull) if (node->getOp() == EOpNull)
{ {
...@@ -1057,7 +1068,7 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -1057,7 +1068,7 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
bool OutputHLSL::visitSelection(Visit visit, TIntermSelection *node) bool OutputHLSL::visitSelection(Visit visit, TIntermSelection *node)
{ {
TInfoSinkBase &out = context.infoSink.obj; TInfoSinkBase &out = mBody;
if (node->usesTernaryOperator()) if (node->usesTernaryOperator())
{ {
...@@ -1098,7 +1109,7 @@ bool OutputHLSL::visitSelection(Visit visit, TIntermSelection *node) ...@@ -1098,7 +1109,7 @@ bool OutputHLSL::visitSelection(Visit visit, TIntermSelection *node)
void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node) void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
{ {
TInfoSinkBase &out = context.infoSink.obj; TInfoSinkBase &out = mBody;
const TType &type = node->getType(); const TType &type = node->getType();
...@@ -1231,7 +1242,7 @@ bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node) ...@@ -1231,7 +1242,7 @@ bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node)
return false; return false;
} }
TInfoSinkBase &out = context.infoSink.obj; TInfoSinkBase &out = mBody;
if (!node->testFirst()) if (!node->testFirst())
{ {
...@@ -1288,7 +1299,7 @@ bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node) ...@@ -1288,7 +1299,7 @@ bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node)
bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node) bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node)
{ {
TInfoSinkBase &out = context.infoSink.obj; TInfoSinkBase &out = mBody;
switch (node->getFlowOp()) switch (node->getFlowOp())
{ {
...@@ -1321,7 +1332,7 @@ bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node) ...@@ -1321,7 +1332,7 @@ bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node)
// Handle loops with more than 255 iterations (unsupported by D3D9) by splitting them // Handle loops with more than 255 iterations (unsupported by D3D9) by splitting them
bool OutputHLSL::handleExcessiveLoop(TIntermLoop *node) bool OutputHLSL::handleExcessiveLoop(TIntermLoop *node)
{ {
TInfoSinkBase &out = context.infoSink.obj; TInfoSinkBase &out = mBody;
// Parse loops of the form: // Parse loops of the form:
// for(int index = initial; index [comparator] limit; index += increment) // for(int index = initial; index [comparator] limit; index += increment)
...@@ -1486,7 +1497,7 @@ bool OutputHLSL::handleExcessiveLoop(TIntermLoop *node) ...@@ -1486,7 +1497,7 @@ bool OutputHLSL::handleExcessiveLoop(TIntermLoop *node)
void OutputHLSL::outputTriplet(Visit visit, const char *preString, const char *inString, const char *postString) void OutputHLSL::outputTriplet(Visit visit, const char *preString, const char *inString, const char *postString)
{ {
TInfoSinkBase &out = context.infoSink.obj; TInfoSinkBase &out = mBody;
if (visit == PreVisit && preString) if (visit == PreVisit && preString)
{ {
......
...@@ -17,10 +17,13 @@ class OutputHLSL : public TIntermTraverser ...@@ -17,10 +17,13 @@ class OutputHLSL : public TIntermTraverser
public: public:
OutputHLSL(TParseContext &context); OutputHLSL(TParseContext &context);
void output();
protected:
void header(); void header();
void footer(); void footer();
protected: // Visit AST nodes and output their code to the body stream
void visitSymbol(TIntermSymbol*); void visitSymbol(TIntermSymbol*);
void visitConstantUnion(TIntermConstantUnion*); void visitConstantUnion(TIntermConstantUnion*);
bool visitBinary(Visit visit, TIntermBinary*); bool visitBinary(Visit visit, TIntermBinary*);
...@@ -39,7 +42,12 @@ class OutputHLSL : public TIntermTraverser ...@@ -39,7 +42,12 @@ class OutputHLSL : public TIntermTraverser
static TString arrayString(const TType &type); static TString arrayString(const TType &type);
static TString initializer(const TType &type); static TString initializer(const TType &type);
TParseContext &context; TParseContext &mContext;
// Output streams
TInfoSinkBase mHeader;
TInfoSinkBase mBody;
TInfoSinkBase mFooter;
}; };
} }
......
...@@ -17,9 +17,7 @@ bool TranslatorHLSL::compile(TIntermNode *root) ...@@ -17,9 +17,7 @@ bool TranslatorHLSL::compile(TIntermNode *root)
TParseContext& parseContext = *GetGlobalParseContext(); TParseContext& parseContext = *GetGlobalParseContext();
sh::OutputHLSL outputHLSL(parseContext); sh::OutputHLSL outputHLSL(parseContext);
outputHLSL.header(); outputHLSL.output();
parseContext.treeRoot->traverse(&outputHLSL);
outputHLSL.footer();
return true; return true;
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment