Implemented struct equality

TRAC #11727 Signed-off-by: Shannon Woods Signed-off-by: Daniel Koch Author: Nicolas Capens git-svn-id: https://angleproject.googlecode.com/svn/trunk@127 736b8ea6-26fd-11df-bfd4-992fa37f6226
parent 950f993b
...@@ -13,6 +13,18 @@ namespace sh ...@@ -13,6 +13,18 @@ namespace sh
{ {
OutputHLSL::OutputHLSL(TParseContext &context) : TIntermTraverser(true, true, true), mContext(context) OutputHLSL::OutputHLSL(TParseContext &context) : TIntermTraverser(true, true, true), mContext(context)
{ {
mUsesEqualMat2 = false;
mUsesEqualMat3 = false;
mUsesEqualMat4 = false;
mUsesEqualVec2 = false;
mUsesEqualVec3 = false;
mUsesEqualVec4 = false;
mUsesEqualIVec2 = false;
mUsesEqualIVec3 = false;
mUsesEqualIVec4 = false;
mUsesEqualBVec2 = false;
mUsesEqualBVec3 = false;
mUsesEqualBVec4 = false;
} }
void OutputHLSL::output() void OutputHLSL::output()
...@@ -401,28 +413,109 @@ void OutputHLSL::header() ...@@ -401,28 +413,109 @@ void OutputHLSL::header()
" return -N;\n" " return -N;\n"
" }\n" " }\n"
"}\n" "}\n"
"\n"
"bool __equal(float2x2 m, float2x2 n)\n"
"{\n"
" return m[0][0] == n[0][0] && m[0][1] == n[0][1] &&\n"
" m[1][0] == n[1][0] && m[1][1] == n[1][1];\n"
"}\n"
"\n"
"bool __equal(float3x3 m, float3x3 n)\n"
"{\n"
" return m[0][0] == n[0][0] && m[0][1] == n[0][1] && m[0][2] == n[0][2] &&\n"
" m[1][0] == n[1][0] && m[1][1] == n[1][1] && m[1][2] == n[1][2] &&\n"
" m[2][0] == n[2][0] && m[2][1] == n[2][1] && m[2][2] == n[2][2];\n"
"}\n"
"\n"
"bool __equal(float4x4 m, float4x4 n)\n"
"{\n"
" return m[0][0] == n[0][0] && m[0][1] == n[0][1] && m[0][2] == n[0][2] && m[0][3] == n[0][3] &&\n"
" m[1][0] == n[1][0] && m[1][1] == n[1][1] && m[1][2] == n[1][2] && m[1][3] == n[1][3] &&\n"
" m[2][0] == n[2][0] && m[2][1] == n[2][1] && m[2][2] == n[2][2] && m[2][3] == n[2][3] &&\n"
" m[3][0] == n[3][0] && m[3][1] == n[3][1] && m[3][2] == n[3][2] && m[3][3] == n[3][3];\n"
"}\n"
"\n"; "\n";
if (mUsesEqualMat2)
{
out << "bool __equal(float2x2 m, float2x2 n)\n"
"{\n"
" return m[0][0] == n[0][0] && m[0][1] == n[0][1] &&\n"
" m[1][0] == n[1][0] && m[1][1] == n[1][1];\n"
"}\n";
}
if (mUsesEqualMat3)
{
out << "bool __equal(float3x3 m, float3x3 n)\n"
"{\n"
" return m[0][0] == n[0][0] && m[0][1] == n[0][1] && m[0][2] == n[0][2] &&\n"
" m[1][0] == n[1][0] && m[1][1] == n[1][1] && m[1][2] == n[1][2] &&\n"
" m[2][0] == n[2][0] && m[2][1] == n[2][1] && m[2][2] == n[2][2];\n"
"}\n";
}
if (mUsesEqualMat4)
{
out << "bool __equal(float4x4 m, float4x4 n)\n"
"{\n"
" return m[0][0] == n[0][0] && m[0][1] == n[0][1] && m[0][2] == n[0][2] && m[0][3] == n[0][3] &&\n"
" m[1][0] == n[1][0] && m[1][1] == n[1][1] && m[1][2] == n[1][2] && m[1][3] == n[1][3] &&\n"
" m[2][0] == n[2][0] && m[2][1] == n[2][1] && m[2][2] == n[2][2] && m[2][3] == n[2][3] &&\n"
" m[3][0] == n[3][0] && m[3][1] == n[3][1] && m[3][2] == n[3][2] && m[3][3] == n[3][3];\n"
"}\n";
}
if (mUsesEqualVec2)
{
out << "bool __equal(float2 v, float2 u)\n"
"{\n"
" return v.x == u.x && v.y == u.y;\n"
"}\n";
}
if (mUsesEqualVec3)
{
out << "bool __equal(float3 v, float3 u)\n"
"{\n"
" return v.x == u.x && v.y == u.y && v.z == u.z;\n"
"}\n";
}
if (mUsesEqualVec4)
{
out << "bool __equal(float4 v, float4 u)\n"
"{\n"
" return v.x == u.x && v.y == u.y && v.z == u.z && v.w == u.w;\n"
"}\n";
}
if (mUsesEqualIVec2)
{
out << "bool __equal(int2 v, int2 u)\n"
"{\n"
" return v.x == u.x && v.y == u.y;\n"
"}\n";
}
if (mUsesEqualIVec3)
{
out << "bool __equal(int3 v, int3 u)\n"
"{\n"
" return v.x == u.x && v.y == u.y && v.z == u.z;\n"
"}\n";
}
if (mUsesEqualIVec4)
{
out << "bool __equal(int4 v, int4 u)\n"
"{\n"
" return v.x == u.x && v.y == u.y && v.z == u.z && v.w == u.w;\n"
"}\n";
}
if (mUsesEqualBVec2)
{
out << "bool __equal(bool2 v, bool2 u)\n"
"{\n"
" return v.x == u.x && v.y == u.y;\n"
"}\n";
}
if (mUsesEqualBVec3)
{
out << "bool __equal(bool3 v, bool3 u)\n"
"{\n"
" return v.x == u.x && v.y == u.y && v.z == u.z;\n"
"}\n";
}
if (mUsesEqualBVec4)
{
out << "bool __equal(bool4 v, bool4 u)\n"
"{\n"
" return v.x == u.x && v.y == u.y && v.z == u.z && v.w == u.w;\n"
"}\n";
}
} }
void OutputHLSL::footer() void OutputHLSL::footer()
...@@ -642,23 +735,106 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node) ...@@ -642,23 +735,106 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
case EOpMul: outputTriplet(visit, "(", " * ", ")"); break; case EOpMul: outputTriplet(visit, "(", " * ", ")"); break;
case EOpDiv: outputTriplet(visit, "(", " / ", ")"); break; case EOpDiv: outputTriplet(visit, "(", " / ", ")"); break;
case EOpEqual: case EOpEqual:
if (!node->getLeft()->isMatrix()) case EOpNotEqual:
{ if (node->getLeft()->isScalar())
outputTriplet(visit, "(", " == ", ")");
}
else
{ {
outputTriplet(visit, "__equal(", ", ", ")"); if (node->getOp() == EOpEqual)
{
outputTriplet(visit, "(", " == ", ")");
}
else
{
outputTriplet(visit, "(", " != ", ")");
}
} }
break; else if (node->getLeft()->getBasicType() == EbtStruct)
case EOpNotEqual:
if (!node->getLeft()->isMatrix())
{ {
outputTriplet(visit, "(", " != ", ")"); if (node->getOp() == EOpEqual)
{
out << "(";
}
else
{
out << "!(";
}
const TTypeList *fields = node->getLeft()->getType().getStruct();
for (size_t i = 0; i < fields->size(); i++)
{
const TType *fieldType = (*fields)[i].type;
node->getLeft()->traverse(this);
out << "." + fieldType->getFieldName() + " == ";
node->getRight()->traverse(this);
out << "." + fieldType->getFieldName();
if (i < fields->size() - 1)
{
out << " && ";
}
}
out << ")";
return false;
} }
else else
{ {
outputTriplet(visit, "!__equal(", ", ", ")"); if (node->getLeft()->isMatrix())
{
switch (node->getLeft()->getSize())
{
case 2 * 2: mUsesEqualMat2 = true; break;
case 3 * 3: mUsesEqualMat3 = true; break;
case 4 * 4: mUsesEqualMat4 = true; break;
default: UNREACHABLE();
}
}
else if (node->getLeft()->isVector())
{
switch (node->getLeft()->getBasicType())
{
case EbtFloat:
switch (node->getLeft()->getSize())
{
case 2: mUsesEqualVec2 = true; break;
case 3: mUsesEqualVec3 = true; break;
case 4: mUsesEqualVec4 = true; break;
default: UNREACHABLE();
}
break;
case EbtInt:
switch (node->getLeft()->getSize())
{
case 2: mUsesEqualIVec2 = true; break;
case 3: mUsesEqualIVec3 = true; break;
case 4: mUsesEqualIVec4 = true; break;
default: UNREACHABLE();
}
break;
case EbtBool:
switch (node->getLeft()->getSize())
{
case 2: mUsesEqualBVec2 = true; break;
case 3: mUsesEqualBVec3 = true; break;
case 4: mUsesEqualBVec4 = true; break;
default: UNREACHABLE();
}
break;
default: UNREACHABLE();
}
}
else UNREACHABLE();
if (node->getOp() == EOpEqual)
{
outputTriplet(visit, "__equal(", ", ", ")");
}
else
{
outputTriplet(visit, "!__equal(", ", ", ")");
}
} }
break; break;
case EOpLessThan: outputTriplet(visit, "(", " < ", ")"); break; case EOpLessThan: outputTriplet(visit, "(", " < ", ")"); break;
...@@ -803,12 +979,6 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -803,12 +979,6 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
EShLanguage language = mContext.language; EShLanguage language = mContext.language;
TInfoSinkBase &out = mBody; TInfoSinkBase &out = mBody;
if (node->getOp() == EOpNull)
{
out.message(EPrefixError, "node is still EOpNull!");
return true;
}
switch (node->getOp()) switch (node->getOp())
{ {
case EOpSequence: outputTriplet(visit, NULL, ";\n", ";\n"); break; case EOpSequence: outputTriplet(visit, NULL, ";\n", ";\n"); break;
...@@ -823,6 +993,11 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node) ...@@ -823,6 +993,11 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
{ {
if (!variable->getAsSymbolNode() || variable->getAsSymbolNode()->getSymbol() != "") // Variable declaration if (!variable->getAsSymbolNode() || variable->getAsSymbolNode()->getSymbol() != "") // Variable declaration
{ {
if (variable->getQualifier() == EvqGlobal)
{
out << "static ";
}
out << typeString(variable->getType()) + " "; out << typeString(variable->getType()) + " ";
for (TIntermSequence::iterator sit = sequence.begin(); sit != sequence.end(); sit++) for (TIntermSequence::iterator sit = sequence.begin(); sit != sequence.end(); sit++)
...@@ -1144,10 +1319,7 @@ void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node) ...@@ -1144,10 +1319,7 @@ void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
default: UNREACHABLE(); default: UNREACHABLE();
} }
} }
else else UNREACHABLE();
{
UNIMPLEMENTED();
}
break; break;
case EbtFloat: case EbtFloat:
if (!matrix) if (!matrix)
...@@ -1184,10 +1356,7 @@ void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node) ...@@ -1184,10 +1356,7 @@ void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
default: UNREACHABLE(); default: UNREACHABLE();
} }
} }
else else UNREACHABLE();
{
UNIMPLEMENTED();
}
break; break;
default: default:
UNIMPLEMENTED(); // FIXME UNIMPLEMENTED(); // FIXME
......
...@@ -48,6 +48,20 @@ class OutputHLSL : public TIntermTraverser ...@@ -48,6 +48,20 @@ class OutputHLSL : public TIntermTraverser
TInfoSinkBase mHeader; TInfoSinkBase mHeader;
TInfoSinkBase mBody; TInfoSinkBase mBody;
TInfoSinkBase mFooter; TInfoSinkBase mFooter;
// Parameters determining what goes in the header output
bool mUsesEqualMat2;
bool mUsesEqualMat3;
bool mUsesEqualMat4;
bool mUsesEqualVec2;
bool mUsesEqualVec3;
bool mUsesEqualVec4;
bool mUsesEqualIVec2;
bool mUsesEqualIVec3;
bool mUsesEqualIVec4;
bool mUsesEqualBVec2;
bool mUsesEqualBVec3;
bool mUsesEqualBVec4;
}; };
} }
......
...@@ -214,6 +214,7 @@ public: ...@@ -214,6 +214,7 @@ public:
void setArrayInformationType(TType* t) { arrayInformationType = t; } void setArrayInformationType(TType* t) { arrayInformationType = t; }
TType* getArrayInformationType() const { return arrayInformationType; } TType* getArrayInformationType() const { return arrayInformationType; }
virtual bool isVector() const { return size > 1 && !matrix; } virtual bool isVector() const { return size > 1 && !matrix; }
virtual bool isScalar() const { return size == 1 && !matrix && !structure; }
static const char* getBasicString(TBasicType t) { static const char* getBasicString(TBasicType t) {
switch (t) { switch (t) {
case EbtVoid: return "void"; break; case EbtVoid: return "void"; break;
......
...@@ -243,6 +243,7 @@ public: ...@@ -243,6 +243,7 @@ public:
virtual bool isMatrix() const { return type.isMatrix(); } virtual bool isMatrix() const { return type.isMatrix(); }
virtual bool isArray() const { return type.isArray(); } virtual bool isArray() const { return type.isArray(); }
virtual bool isVector() const { return type.isVector(); } virtual bool isVector() const { return type.isVector(); }
virtual bool isScalar() const { return type.isScalar(); }
const char* getBasicString() const { return type.getBasicString(); } const char* getBasicString() const { return type.getBasicString(); }
const char* getQualifierString() const { return type.getQualifierString(); } const char* getQualifierString() const { return type.getQualifierString(); }
TString getCompleteString() const { return type.getCompleteString(); } TString getCompleteString() const { return type.getCompleteString(); }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment