Commit 7fb4955d by Olli Etuaho

Support array equality operator in HLSL output

This requires adding functions to the shader source that can do the comparison for a specific array size. There's no automated test coverage specifically for this functionality, since all deqp tests that cover this also require array constructors to be supported. The change has been tested by manually inspecting shader output. No regressions were seen in automated tests listed below. TEST=dEQP-GLES3.functional.shaders.*, angle_unittests BUG=angleproject:941 Change-Id: Ie2ca7c016a3f0bcb3392a96d6d20d6f803d28bf0 Reviewed-on: https://chromium-review.googlesource.com/261530Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org> Reviewed-by: 's avatarNicolas Capens <capn@chromium.org> Tested-by: 's avatarOlli Etuaho <oetuaho@nvidia.com>
parent 8fcd4e0c
...@@ -338,6 +338,15 @@ void OutputHLSL::header(const BuiltInFunctionEmulator *builtInFunctionEmulator) ...@@ -338,6 +338,15 @@ void OutputHLSL::header(const BuiltInFunctionEmulator *builtInFunctionEmulator)
} }
} }
if (!mArrayEqualityFunctions.empty())
{
out << "\n// Array equality functions\n\n";
for (const auto &eqFunction : mArrayEqualityFunctions)
{
out << eqFunction.functionDefinition << "\n";
}
}
if (mUsesDiscardRewriting) if (mUsesDiscardRewriting)
{ {
out << "#define ANGLE_USES_DISCARD_REWRITING\n"; out << "#define ANGLE_USES_DISCARD_REWRITING\n";
...@@ -1380,6 +1389,45 @@ void OutputHLSL::visitRaw(TIntermRaw *node) ...@@ -1380,6 +1389,45 @@ void OutputHLSL::visitRaw(TIntermRaw *node)
getInfoSink() << node->getRawText(); getInfoSink() << node->getRawText();
} }
void OutputHLSL::outputEqual(Visit visit, const TType &type, TOperator op, TInfoSinkBase &out)
{
if (type.isScalar() && !type.isArray())
{
if (op == EOpEqual)
{
outputTriplet(visit, "(", " == ", ")", out);
}
else
{
outputTriplet(visit, "(", " != ", ")", out);
}
}
else
{
if (visit == PreVisit && op == EOpNotEqual)
{
out << "!";
}
if (type.isArray())
{
const TString &functionName = addArrayEqualityFunction(type);
outputTriplet(visit, (functionName + "(").c_str(), ", ", ")", out);
}
else if (type.getBasicType() == EbtStruct)
{
const TStructure &structure = *type.getStruct();
const TString &functionName = addStructEqualityFunction(structure);
outputTriplet(visit, (functionName + "(").c_str(), ", ", ")", out);
}
else
{
ASSERT(type.isMatrix() || type.isVector());
outputTriplet(visit, "all(", " == ", ")", out);
}
}
}
bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node) bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
{ {
TInfoSinkBase &out = getInfoSink(); TInfoSinkBase &out = getInfoSink();
...@@ -1574,40 +1622,7 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node) ...@@ -1574,40 +1622,7 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
case EOpBitwiseOr: outputTriplet(visit, "(", " | ", ")"); break; case EOpBitwiseOr: outputTriplet(visit, "(", " | ", ")"); break;
case EOpEqual: case EOpEqual:
case EOpNotEqual: case EOpNotEqual:
if (node->getLeft()->isArray()) outputEqual(visit, node->getLeft()->getType(), node->getOp(), out);
{
UNIMPLEMENTED();
}
else if (node->getLeft()->isScalar())
{
if (node->getOp() == EOpEqual)
{
outputTriplet(visit, "(", " == ", ")");
}
else
{
outputTriplet(visit, "(", " != ", ")");
}
}
else
{
if (visit == PreVisit && node->getOp() == EOpNotEqual)
{
out << "!";
}
if (node->getLeft()->getBasicType() == EbtStruct)
{
const TStructure &structure = *node->getLeft()->getType().getStruct();
const TString &functionName = addStructEqualityFunction(structure);
outputTriplet(visit, (functionName + "(").c_str(), ", ", ")");
}
else
{
ASSERT(node->getLeft()->isMatrix() || node->getLeft()->isVector());
outputTriplet(visit, "all(", " == ", ")");
}
}
break; break;
case EOpLessThan: outputTriplet(visit, "(", " < ", ")"); break; case EOpLessThan: outputTriplet(visit, "(", " < ", ")"); break;
case EOpGreaterThan: outputTriplet(visit, "(", " > ", ")"); break; case EOpGreaterThan: outputTriplet(visit, "(", " > ", ")"); break;
...@@ -2723,10 +2738,8 @@ bool OutputHLSL::handleExcessiveLoop(TIntermLoop *node) ...@@ -2723,10 +2738,8 @@ bool OutputHLSL::handleExcessiveLoop(TIntermLoop *node)
return false; // Not handled as an excessive loop return false; // Not handled as an excessive loop
} }
void OutputHLSL::outputTriplet(Visit visit, const char *preString, const char *inString, const char *postString) void OutputHLSL::outputTriplet(Visit visit, const char *preString, const char *inString, const char *postString, TInfoSinkBase &out)
{ {
TInfoSinkBase &out = getInfoSink();
if (visit == PreVisit) if (visit == PreVisit)
{ {
out << preString; out << preString;
...@@ -2741,6 +2754,11 @@ void OutputHLSL::outputTriplet(Visit visit, const char *preString, const char *i ...@@ -2741,6 +2754,11 @@ void OutputHLSL::outputTriplet(Visit visit, const char *preString, const char *i
} }
} }
void OutputHLSL::outputTriplet(Visit visit, const char *preString, const char *inString, const char *postString)
{
outputTriplet(visit, preString, inString, postString, getInfoSink());
}
void OutputHLSL::outputLineDirective(int line) void OutputHLSL::outputLineDirective(int line)
{ {
if ((mCompileOptions & SH_LINE_DIRECTIVES) && (line > 0)) if ((mCompileOptions & SH_LINE_DIRECTIVES) && (line > 0))
...@@ -2969,12 +2987,19 @@ TString OutputHLSL::addStructEqualityFunction(const TStructure &structure) ...@@ -2969,12 +2987,19 @@ TString OutputHLSL::addStructEqualityFunction(const TStructure &structure)
const TString &fieldNameA = "a." + Decorate(field->name()); const TString &fieldNameA = "a." + Decorate(field->name());
const TString &fieldNameB = "b." + Decorate(field->name()); const TString &fieldNameB = "b." + Decorate(field->name());
// TODO (oetuaho): Use outputEqual() here instead
if (i > 0) if (i > 0)
{ {
func += " && "; func += " && ";
} }
if (fieldType->getBasicType() == EbtStruct) if (fieldType->isArray())
{
// TODO (oetuaho): This requires sorting array and struct equality functions together.
UNIMPLEMENTED();
}
else if (fieldType->getBasicType() == EbtStruct)
{ {
const TStructure &fieldStruct = *fieldType->getStruct(); const TStructure &fieldStruct = *fieldType->getStruct();
const TString &functionName = addStructEqualityFunction(fieldStruct); const TString &functionName = addStructEqualityFunction(fieldStruct);
...@@ -2998,4 +3023,54 @@ TString OutputHLSL::addStructEqualityFunction(const TStructure &structure) ...@@ -2998,4 +3023,54 @@ TString OutputHLSL::addStructEqualityFunction(const TStructure &structure)
return function.functionName; return function.functionName;
} }
TString OutputHLSL::addArrayEqualityFunction(const TType& type)
{
for (const auto &eqFunction : mArrayEqualityFunctions)
{
if (eqFunction.type == type)
{
return eqFunction.functionName;
}
}
const TString &typeName = TypeString(type);
ArrayEqualityFunction function;
function.type = type;
TInfoSinkBase fnNameOut;
fnNameOut << "angle_eq_" << type.getArraySize() << "_" << typeName;
function.functionName = fnNameOut.c_str();
TType nonArrayType = type;
nonArrayType.clearArrayness();
TInfoSinkBase fnOut;
fnOut << "bool " << function.functionName << "("
<< typeName << "[" << type.getArraySize() << "] a, "
<< typeName << "[" << type.getArraySize() << "] b)\n"
<< "{\n"
" for (int i = 0; i < " << type.getArraySize() << "; ++i)\n"
" {\n"
" if (";
outputEqual(PreVisit, nonArrayType, EOpNotEqual, fnOut);
fnOut << "a[i]";
outputEqual(InVisit, nonArrayType, EOpNotEqual, fnOut);
fnOut << "b[i]";
outputEqual(PostVisit, nonArrayType, EOpNotEqual, fnOut);
fnOut << ") { return false; }\n"
" }\n"
" return true;\n"
"}\n";
function.functionDefinition = fnOut.c_str();
mArrayEqualityFunctions.push_back(function);
return function.functionName;
}
} }
...@@ -67,6 +67,7 @@ class OutputHLSL : public TIntermTraverser ...@@ -67,6 +67,7 @@ class OutputHLSL : public TIntermTraverser
bool handleExcessiveLoop(TIntermLoop *node); bool handleExcessiveLoop(TIntermLoop *node);
// Emit one of three strings depending on traverse phase. Called with literal strings so using const char* instead of TString. // Emit one of three strings depending on traverse phase. Called with literal strings so using const char* instead of TString.
void outputTriplet(Visit visit, const char *preString, const char *inString, const char *postString, TInfoSinkBase &out);
void outputTriplet(Visit visit, const char *preString, const char *inString, const char *postString); void outputTriplet(Visit visit, const char *preString, const char *inString, const char *postString);
void outputLineDirective(int line); void outputLineDirective(int line);
TString argumentString(const TIntermSymbol *symbol); TString argumentString(const TIntermSymbol *symbol);
...@@ -76,6 +77,8 @@ class OutputHLSL : public TIntermTraverser ...@@ -76,6 +77,8 @@ class OutputHLSL : public TIntermTraverser
void outputConstructor(Visit visit, const TType &type, const char *name, const TIntermSequence *parameters); void outputConstructor(Visit visit, const TType &type, const char *name, const TIntermSequence *parameters);
const ConstantUnion *writeConstantUnion(const TType &type, const ConstantUnion *constUnion); const ConstantUnion *writeConstantUnion(const TType &type, const ConstantUnion *constUnion);
void outputEqual(Visit visit, const TType &type, TOperator op, TInfoSinkBase &out);
void writeEmulatedFunctionTriplet(Visit visit, const char *preStr); void writeEmulatedFunctionTriplet(Visit visit, const char *preStr);
void makeFlaggedStructMaps(const std::vector<TIntermTyped *> &flaggedStructs); void makeFlaggedStructMaps(const std::vector<TIntermTyped *> &flaggedStructs);
...@@ -85,6 +88,7 @@ class OutputHLSL : public TIntermTraverser ...@@ -85,6 +88,7 @@ class OutputHLSL : public TIntermTraverser
// Returns the function name // Returns the function name
TString addStructEqualityFunction(const TStructure &structure); TString addStructEqualityFunction(const TStructure &structure);
TString addArrayEqualityFunction(const TType &type);
sh::GLenum mShaderType; sh::GLenum mShaderType;
int mShaderVersion; int mShaderVersion;
...@@ -103,6 +107,7 @@ class OutputHLSL : public TIntermTraverser ...@@ -103,6 +107,7 @@ class OutputHLSL : public TIntermTraverser
// A stack is useful when we want to traverse in the header, or in helper functions, but not always // A stack is useful when we want to traverse in the header, or in helper functions, but not always
// write to the body. Instead use an InfoSink stack to keep our current state intact. // write to the body. Instead use an InfoSink stack to keep our current state intact.
// TODO (jmadill): Just passing an InfoSink in function parameters would be simpler.
std::stack<TInfoSinkBase *> mInfoSinkStack; std::stack<TInfoSinkBase *> mInfoSinkStack;
ReferencedSymbols mReferencedUniforms; ReferencedSymbols mReferencedUniforms;
...@@ -191,6 +196,15 @@ class OutputHLSL : public TIntermTraverser ...@@ -191,6 +196,15 @@ class OutputHLSL : public TIntermTraverser
}; };
std::vector<StructEqualityFunction> mStructEqualityFunctions; std::vector<StructEqualityFunction> mStructEqualityFunctions;
struct ArrayEqualityFunction
{
TType type;
TString functionName;
TString functionDefinition;
};
std::vector<ArrayEqualityFunction> mArrayEqualityFunctions;
}; };
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment