Commit 196b6e24 by John Kessenich Committed by GitHub

Merge pull request #536 from steve-lunarg/flatten-assign-fix

HLSL: fix for flattening assignments from non-symbol R-values.
parents ce5d4afc 2199c240
struct PS_OUTPUT
{
float4 color : SV_Target0;
float other_struct_member1;
float other_struct_member2;
float other_struct_member3;
};
PS_OUTPUT Func1()
{
return PS_OUTPUT(float4(1), 2, 3, 4);
}
PS_OUTPUT main()
{
return Func1();
}
......@@ -73,6 +73,16 @@ TIntermSymbol* TIntermediate::addSymbol(int id, const TString& name, const TType
return node;
}
TIntermSymbol* TIntermediate::addSymbol(const TIntermSymbol& intermSymbol)
{
return addSymbol(intermSymbol.getId(),
intermSymbol.getName(),
intermSymbol.getType(),
intermSymbol.getConstArray(),
intermSymbol.getConstSubtree(),
intermSymbol.getLoc());
}
TIntermSymbol* TIntermediate::addSymbol(const TVariable& variable)
{
glslang::TSourceLoc loc; // just a null location
......
......@@ -201,6 +201,7 @@ public:
TIntermSymbol* addSymbol(const TVariable&);
TIntermSymbol* addSymbol(const TVariable&, const TSourceLoc&);
TIntermSymbol* addSymbol(const TType&, const TSourceLoc&);
TIntermSymbol* addSymbol(const TIntermSymbol&);
TIntermTyped* addConversion(TOperator, const TType&, TIntermTyped*) const;
TIntermTyped* addShapeConversion(TOperator, const TType&, TIntermTyped*);
TIntermTyped* addBinaryMath(TOperator, TIntermTyped* left, TIntermTyped* right, TSourceLoc);
......
......@@ -99,6 +99,7 @@ INSTANTIATE_TEST_CASE_P(
{"hlsl.entry-out.frag", "PixelShaderFunction"},
{"hlsl.float1.frag", "PixelShaderFunction"},
{"hlsl.float4.frag", "PixelShaderFunction"},
{"hlsl.flatten.return.frag", "main"},
{"hlsl.forLoop.frag", "PixelShaderFunction"},
{"hlsl.gather.array.dx10.frag", "main"},
{"hlsl.gather.basic.dx10.frag", "main"},
......
......@@ -952,10 +952,53 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op
const TVector<TVariable*>* leftVariables = nullptr;
const TVector<TVariable*>* rightVariables = nullptr;
// A temporary to store the right node's value, so we don't keep indirecting into it
// if it's not a simple symbol.
TVariable* rhsTempVar = nullptr;
// If the RHS is a simple symbol node, we'll copy it for each member.
TIntermSymbol* cloneSymNode = nullptr;
// Array structs are not yet handled in flattening. (Compilation error upstream, so
// this should never fire).
assert(!(left->getType().isStruct() && left->getType().isArray()));
int memberCount = 0;
// Track how many items there are to copy.
if (left->getType().isStruct())
memberCount = left->getType().getStruct()->size();
if (left->getType().isArray())
memberCount = left->getType().getCumulativeArraySize();
if (flattenLeft)
leftVariables = &flattenMap.find(left->getAsSymbolNode()->getId())->second;
if (flattenRight)
if (flattenRight) {
rightVariables = &flattenMap.find(right->getAsSymbolNode()->getId())->second;
} else {
// The RHS is not flattened. There are several cases:
// 1. 1 item to copy: Use the RHS directly.
// 2. >1 item, simple symbol RHS: we'll create a new TIntermSymbol node for each, but no assign to temp.
// 3. >1 item, complex RHS: assign it to a new temp variable, and create a TIntermSymbol for each member.
if (memberCount <= 1) {
// case 1: we'll use the symbol directly below. Nothing to do.
} else {
if (right->getAsSymbolNode() != nullptr) {
// case 2: we'll copy the symbol per iteration below.
cloneSymNode = right->getAsSymbolNode();
} else {
// case 3: assign to a temp, and indirect into that.
rhsTempVar = makeInternalVariable("flattenTemp", right->getType());
rhsTempVar->getWritableType().getQualifier().makeTemporary();
TIntermTyped* noFlattenRHS = intermediate.addSymbol(*rhsTempVar, loc);
// Add this to the aggregate being built.
assignList = intermediate.growAggregate(assignList, intermediate.addAssign(op, noFlattenRHS, right, loc), loc);
}
}
}
const auto getMember = [&](bool flatten, TIntermTyped* node,
const TVector<TVariable*>& memberVariables, int member,
......@@ -971,6 +1014,14 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op
return subTree;
};
// Return the proper RHS node: a new symbol from a TVariable, copy
// of an TIntermSymbol node, or sometimes the right node directly.
const auto getRHS = [&]() {
return rhsTempVar ? intermediate.addSymbol(*rhsTempVar, loc) :
cloneSymNode ? intermediate.addSymbol(*cloneSymNode) :
right;
};
// Handle struct assignment
if (left->getType().isStruct()) {
// If we get here, we are assigning to or from a whole struct that must be
......@@ -978,7 +1029,7 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op
const auto& members = *left->getType().getStruct();
for (int member = 0; member < (int)members.size(); ++member) {
TIntermTyped* subRight = getMember(flattenRight, right, *rightVariables, member,
TIntermTyped* subRight = getMember(flattenRight, getRHS(), *rightVariables, member,
EOpIndexDirectStruct, *members[member].type);
TIntermTyped* subLeft = getMember(flattenLeft, left, *leftVariables, member,
EOpIndexDirectStruct, *members[member].type);
......@@ -992,10 +1043,10 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op
// flattened, so have to do member-by-member assignment:
const TType dereferencedType(left->getType(), 0);
const int size = left->getType().getCumulativeArraySize();
for (int element=0; element < size; ++element) {
TIntermTyped* subRight = getMember(flattenRight, right, *rightVariables, element,
for (int element=0; element < memberCount; ++element) {
// Add a new AST symbol node if we have a temp variable holding a complex RHS.
TIntermTyped* subRight = getMember(flattenRight, getRHS(), *rightVariables, element,
EOpIndexDirect, dereferencedType);
TIntermTyped* subLeft = getMember(flattenLeft, left, *leftVariables, element,
EOpIndexDirect, dereferencedType);
......@@ -1235,9 +1286,9 @@ void HlslParseContext::decomposeSampleMethods(const TSourceLoc& loc, TIntermType
// Return value from size query
TVariable* tempArg = makeInternalVariable("sizeQueryTemp", sizeQuery->getType());
tempArg->getWritableType().getQualifier().makeTemporary();
TIntermSymbol* sizeQueryReturn = intermediate.addSymbol(*tempArg, loc);
TIntermTyped* sizeQueryAssign = intermediate.addAssign(EOpAssign, sizeQueryReturn, sizeQuery, loc);
TIntermTyped* sizeQueryAssign = intermediate.addAssign(EOpAssign,
intermediate.addSymbol(*tempArg, loc),
sizeQuery, loc);
// Compound statement for assigning outputs
TIntermAggregate* compoundStatement = intermediate.makeAggregate(sizeQueryAssign, loc);
......@@ -1246,6 +1297,7 @@ void HlslParseContext::decomposeSampleMethods(const TSourceLoc& loc, TIntermType
for (int compNum = 0; compNum < numDims; ++compNum) {
TIntermTyped* indexedOut = nullptr;
TIntermSymbol* sizeQueryReturn = intermediate.addSymbol(*tempArg, loc);
if (numDims > 1) {
TIntermTyped* component = intermediate.addConstantUnion(compNum, loc, true);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment