Commit acb9076a by John Kessenich Committed by GitHub

Merge pull request #650 from steve-lunarg/lvalue-swizzle-fix

HLSL: allow destination swizzles when writing RWTexture/RWBuffer
parents 085b8334 cd6829ba
RWTexture2D<float3> rwtx;
RWBuffer<float3> buf;
float3 SomeValue() { return float3(1,2,3); }
float4 main() : SV_Target0
{
int2 tc2 = { 0, 0 };
int tc = 0;
// Test swizzles and partial updates of L-values when writing to buffers and writable textures.
rwtx[tc2].zyx = float3(1,2,3); // full swizzle, simple RHS
rwtx[tc2].zyx = SomeValue(); // full swizzle, complex RHS
rwtx[tc2].zyx = 2; // full swizzle, modify op
// Partial updates not yet supported.
// Partial values, which will use swizzles.
// buf[tc].yz = 42; // partial swizzle, simple RHS
// buf[tc].yz = SomeValue().x; // partial swizzle, complex RHS
// buf[tc].yz += 43; // partial swizzle, modify op
// // Partial values, which will use index.
// buf[tc].y = 44; // single index, simple RHS
// buf[tc].y = SomeValue().x; // single index, complex RHS
// buf[tc].y += 45; // single index, modify op
return 0.0;
}
...@@ -174,6 +174,7 @@ INSTANTIATE_TEST_CASE_P( ...@@ -174,6 +174,7 @@ INSTANTIATE_TEST_CASE_P(
{"hlsl.rw.bracket.frag", "main"}, {"hlsl.rw.bracket.frag", "main"},
{"hlsl.rw.register.frag", "main"}, {"hlsl.rw.register.frag", "main"},
{"hlsl.rw.scalar.bracket.frag", "main"}, {"hlsl.rw.scalar.bracket.frag", "main"},
{"hlsl.rw.swizzle.frag", "main"},
{"hlsl.rw.vec2.bracket.frag", "main"}, {"hlsl.rw.vec2.bracket.frag", "main"},
{"hlsl.sample.array.dx10.frag", "main"}, {"hlsl.sample.array.dx10.frag", "main"},
{"hlsl.sample.basic.dx10.frag", "main"}, {"hlsl.sample.basic.dx10.frag", "main"},
......
...@@ -47,6 +47,7 @@ ...@@ -47,6 +47,7 @@
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <cctype> #include <cctype>
#include <array>
namespace glslang { namespace glslang {
...@@ -147,6 +148,12 @@ bool HlslParseContext::shouldConvertLValue(const TIntermNode* node) const ...@@ -147,6 +148,12 @@ bool HlslParseContext::shouldConvertLValue(const TIntermNode* node) const
return false; return false;
const TIntermAggregate* lhsAsAggregate = node->getAsAggregate(); const TIntermAggregate* lhsAsAggregate = node->getAsAggregate();
const TIntermBinary* lhsAsBinary = node->getAsBinaryNode();
// If it's a swizzled/indexed aggregate, look at the left node instead.
if (lhsAsBinary != nullptr &&
(lhsAsBinary->getOp() == EOpVectorSwizzle || lhsAsBinary->getOp() == EOpIndexDirect))
lhsAsAggregate = lhsAsBinary->getLeft()->getAsAggregate();
if (lhsAsAggregate != nullptr && lhsAsAggregate->getOp() == EOpImageLoad) if (lhsAsAggregate != nullptr && lhsAsAggregate->getOp() == EOpImageLoad)
return true; return true;
...@@ -285,6 +292,34 @@ TIntermTyped* HlslParseContext::handleLvalue(const TSourceLoc& loc, const char* ...@@ -285,6 +292,34 @@ TIntermTyped* HlslParseContext::handleLvalue(const TSourceLoc& loc, const char*
loc); loc);
}; };
// Return true if swizzle or index writes all components of the given variable.
const auto writesAllComponents = [&](TIntermSymbol* var, TIntermBinary* swizzle) -> bool {
if (swizzle == nullptr) // not a swizzle or index
return true;
// Track which components are being set.
std::array<bool, 4> compIsSet;
compIsSet.fill(false);
const TIntermConstantUnion* asConst = swizzle->getRight()->getAsConstantUnion();
const TIntermAggregate* asAggregate = swizzle->getRight()->getAsAggregate();
// This could be either a direct index, or a swizzle.
if (asConst) {
compIsSet[asConst->getConstArray()[0].getIConst()] = true;
} else if (asAggregate) {
const TIntermSequence& seq = asAggregate->getSequence();
for (int comp=0; comp<int(seq.size()); ++comp)
compIsSet[seq[comp]->getAsConstantUnion()->getConstArray()[0].getIConst()] = true;
} else {
assert(0);
}
// Return true if all components are being set by the index or swizzle
return std::all_of(compIsSet.begin(), compIsSet.begin() + var->getType().getVectorSize(),
[](bool isSet) { return isSet; } );
};
// helper to create a temporary variable // helper to create a temporary variable
const auto addTmpVar = [&](const char* name, const TType& derefType) -> TIntermSymbol* { const auto addTmpVar = [&](const char* name, const TType& derefType) -> TIntermSymbol* {
TVariable* tmpVar = makeInternalVariable(name, derefType); TVariable* tmpVar = makeInternalVariable(name, derefType);
...@@ -292,7 +327,24 @@ TIntermTyped* HlslParseContext::handleLvalue(const TSourceLoc& loc, const char* ...@@ -292,7 +327,24 @@ TIntermTyped* HlslParseContext::handleLvalue(const TSourceLoc& loc, const char*
return intermediate.addSymbol(*tmpVar, loc); return intermediate.addSymbol(*tmpVar, loc);
}; };
// Create swizzle matching input swizzle
const auto addSwizzle = [&](TIntermSymbol* var, TIntermBinary* swizzle) -> TIntermTyped* {
if (swizzle)
return intermediate.addBinaryNode(swizzle->getOp(), var, swizzle->getRight(), loc, swizzle->getType());
else
return var;
};
TIntermBinary* lhsAsBinary = lhs->getAsBinaryNode();
TIntermAggregate* lhsAsAggregate = lhs->getAsAggregate(); TIntermAggregate* lhsAsAggregate = lhs->getAsAggregate();
bool lhsIsSwizzle = false;
// If it's a swizzled L-value, remember the swizzle, and use the LHS.
if (lhsAsBinary != nullptr && (lhsAsBinary->getOp() == EOpVectorSwizzle || lhsAsBinary->getOp() == EOpIndexDirect)) {
lhsAsAggregate = lhsAsBinary->getLeft()->getAsAggregate();
lhsIsSwizzle = true;
}
TIntermTyped* object = lhsAsAggregate->getSequence()[0]->getAsTyped(); TIntermTyped* object = lhsAsAggregate->getSequence()[0]->getAsTyped();
TIntermTyped* coord = lhsAsAggregate->getSequence()[1]->getAsTyped(); TIntermTyped* coord = lhsAsAggregate->getSequence()[1]->getAsTyped();
...@@ -339,16 +391,26 @@ TIntermTyped* HlslParseContext::handleLvalue(const TSourceLoc& loc, const char* ...@@ -339,16 +391,26 @@ TIntermTyped* HlslParseContext::handleLvalue(const TSourceLoc& loc, const char*
// OpSequence // OpSequence
// coordtmp = load's param1 // coordtmp = load's param1
// rhsTmp = OpImageLoad(object, coordTmp) // rhsTmp = OpImageLoad(object, coordTmp)
// rhsTmp op= rhs // rhsTmp op = rhs
// OpImageStore(object, coordTmp, rhsTmp) // OpImageStore(object, coordTmp, rhsTmp)
// rhsTmp // rhsTmp
//
// If the lvalue is swizzled, we apply that when writing the temp variable, like so:
// ...
// rhsTmp.some_swizzle = ...
// For partial writes, an error is generated.
TIntermSymbol* rhsTmp = rhs->getAsSymbolNode(); TIntermSymbol* rhsTmp = rhs->getAsSymbolNode();
TIntermTyped* coordTmp = coord; TIntermTyped* coordTmp = coord;
if (rhsTmp == nullptr || isModifyOp) { if (rhsTmp == nullptr || isModifyOp || lhsIsSwizzle) {
rhsTmp = addTmpVar("storeTemp", objDerefType); rhsTmp = addTmpVar("storeTemp", objDerefType);
// Partial updates not yet supported
if (!writesAllComponents(rhsTmp, lhsAsBinary)) {
error(loc, "unimplemented: partial image updates", "", "");
}
// Assign storeTemp = rhs // Assign storeTemp = rhs
if (isModifyOp) { if (isModifyOp) {
// We have to make a temp var for the coordinate, to avoid evaluating it twice. // We have to make a temp var for the coordinate, to avoid evaluating it twice.
...@@ -358,7 +420,7 @@ TIntermTyped* HlslParseContext::handleLvalue(const TSourceLoc& loc, const char* ...@@ -358,7 +420,7 @@ TIntermTyped* HlslParseContext::handleLvalue(const TSourceLoc& loc, const char*
} }
// rhsTmp op= rhs. // rhsTmp op= rhs.
makeBinary(assignOp, intermediate.addSymbol(*rhsTmp), rhs); makeBinary(assignOp, addSwizzle(intermediate.addSymbol(*rhsTmp), lhsAsBinary), rhs);
} }
makeStore(object, coordTmp, rhsTmp); // add a store makeStore(object, coordTmp, rhsTmp); // add a store
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment