Commit e752f463 by steve-lunarg

HLSL: HS return is arrayed to match SPIR-V semantics

HLSL HS outputs a per ctrl point value, and the DS reads an array of that type. (It also has a per patch frequency). The per-ctrl-pt frequency is arrayed on just one side, as opposed to SPIR-V which is arrayed on both. To match semantics, the compiler creates an array behind the scenes and indexes it by invocation ID, assigning the HS return value to it.
parent 7afe1344
...@@ -27,7 +27,9 @@ vertices = 3 ...@@ -27,7 +27,9 @@ vertices = 3
0:? 'ip' ( temp 3-element array of structure{ temp 3-component vector of float cpoint}) 0:? 'ip' ( temp 3-element array of structure{ temp 3-component vector of float cpoint})
0:? 'ip' (layout( location=0) in 3-element array of structure{ temp 3-component vector of float cpoint}) 0:? 'ip' (layout( location=0) in 3-element array of structure{ temp 3-component vector of float cpoint})
0:26 move second child to first child ( temp structure{ temp 3-component vector of float cpoint}) 0:26 move second child to first child ( temp structure{ temp 3-component vector of float cpoint})
0:? '@entryPointOutput' (layout( location=0) out structure{ temp 3-component vector of float cpoint}) 0:26 indirect index ( temp structure{ temp 3-component vector of float cpoint})
0:? '@entryPointOutput' (layout( location=0) out 3-element array of structure{ temp 3-component vector of float cpoint})
0:? 'InvocationId' ( in uint InvocationID)
0:26 Function Call: @main(struct-VS_OUT-vf31[3]; ( temp structure{ temp 3-component vector of float cpoint}) 0:26 Function Call: @main(struct-VS_OUT-vf31[3]; ( temp structure{ temp 3-component vector of float cpoint})
0:? 'ip' ( temp 3-element array of structure{ temp 3-component vector of float cpoint}) 0:? 'ip' ( temp 3-element array of structure{ temp 3-component vector of float cpoint})
0:? Barrier ( temp void) 0:? Barrier ( temp void)
...@@ -43,7 +45,7 @@ vertices = 3 ...@@ -43,7 +45,7 @@ vertices = 3
0:33 Function Definition: PCF( ( temp void) 0:33 Function Definition: PCF( ( temp void)
0:33 Function Parameters: 0:33 Function Parameters:
0:? Linker Objects 0:? Linker Objects
0:? '@entryPointOutput' (layout( location=0) out structure{ temp 3-component vector of float cpoint}) 0:? '@entryPointOutput' (layout( location=0) out 3-element array of structure{ temp 3-component vector of float cpoint})
0:? 'ip' (layout( location=0) in 3-element array of structure{ temp 3-component vector of float cpoint}) 0:? 'ip' (layout( location=0) in 3-element array of structure{ temp 3-component vector of float cpoint})
0:? 'InvocationId' ( in uint InvocationID) 0:? 'InvocationId' ( in uint InvocationID)
...@@ -79,7 +81,9 @@ vertices = 3 ...@@ -79,7 +81,9 @@ vertices = 3
0:? 'ip' ( temp 3-element array of structure{ temp 3-component vector of float cpoint}) 0:? 'ip' ( temp 3-element array of structure{ temp 3-component vector of float cpoint})
0:? 'ip' (layout( location=0) in 3-element array of structure{ temp 3-component vector of float cpoint}) 0:? 'ip' (layout( location=0) in 3-element array of structure{ temp 3-component vector of float cpoint})
0:26 move second child to first child ( temp structure{ temp 3-component vector of float cpoint}) 0:26 move second child to first child ( temp structure{ temp 3-component vector of float cpoint})
0:? '@entryPointOutput' (layout( location=0) out structure{ temp 3-component vector of float cpoint}) 0:26 indirect index ( temp structure{ temp 3-component vector of float cpoint})
0:? '@entryPointOutput' (layout( location=0) out 3-element array of structure{ temp 3-component vector of float cpoint})
0:? 'InvocationId' ( in uint InvocationID)
0:26 Function Call: @main(struct-VS_OUT-vf31[3]; ( temp structure{ temp 3-component vector of float cpoint}) 0:26 Function Call: @main(struct-VS_OUT-vf31[3]; ( temp structure{ temp 3-component vector of float cpoint})
0:? 'ip' ( temp 3-element array of structure{ temp 3-component vector of float cpoint}) 0:? 'ip' ( temp 3-element array of structure{ temp 3-component vector of float cpoint})
0:? Barrier ( temp void) 0:? Barrier ( temp void)
...@@ -95,18 +99,18 @@ vertices = 3 ...@@ -95,18 +99,18 @@ vertices = 3
0:33 Function Definition: PCF( ( temp void) 0:33 Function Definition: PCF( ( temp void)
0:33 Function Parameters: 0:33 Function Parameters:
0:? Linker Objects 0:? Linker Objects
0:? '@entryPointOutput' (layout( location=0) out structure{ temp 3-component vector of float cpoint}) 0:? '@entryPointOutput' (layout( location=0) out 3-element array of structure{ temp 3-component vector of float cpoint})
0:? 'ip' (layout( location=0) in 3-element array of structure{ temp 3-component vector of float cpoint}) 0:? 'ip' (layout( location=0) in 3-element array of structure{ temp 3-component vector of float cpoint})
0:? 'InvocationId' ( in uint InvocationID) 0:? 'InvocationId' ( in uint InvocationID)
// Module Version 10000 // Module Version 10000
// Generated by (magic number): 80001 // Generated by (magic number): 80001
// Id's are bound by 51 // Id's are bound by 55
Capability Tessellation Capability Tessellation
1: ExtInstImport "GLSL.std.450" 1: ExtInstImport "GLSL.std.450"
MemoryModel Logical GLSL450 MemoryModel Logical GLSL450
EntryPoint TessellationControl 4 "main" 33 36 44 EntryPoint TessellationControl 4 "main" 33 37 39
ExecutionMode 4 OutputVertices 3 ExecutionMode 4 OutputVertices 3
Name 4 "main" Name 4 "main"
Name 8 "VS_OUT" Name 8 "VS_OUT"
...@@ -119,12 +123,12 @@ vertices = 3 ...@@ -119,12 +123,12 @@ vertices = 3
Name 21 "output" Name 21 "output"
Name 31 "ip" Name 31 "ip"
Name 33 "ip" Name 33 "ip"
Name 36 "@entryPointOutput" Name 37 "@entryPointOutput"
Name 37 "param" Name 39 "InvocationId"
Name 44 "InvocationId" Name 41 "param"
Decorate 33(ip) Location 0 Decorate 33(ip) Location 0
Decorate 36(@entryPointOutput) Location 0 Decorate 37(@entryPointOutput) Location 0
Decorate 44(InvocationId) BuiltIn InvocationId Decorate 39(InvocationId) BuiltIn InvocationId
2: TypeVoid 2: TypeVoid
3: TypeFunction 2 3: TypeFunction 2
6: TypeFloat 32 6: TypeFloat 32
...@@ -142,33 +146,37 @@ vertices = 3 ...@@ -142,33 +146,37 @@ vertices = 3
24: TypePointer Function 7(fvec3) 24: TypePointer Function 7(fvec3)
32: TypePointer Input 11 32: TypePointer Input 11
33(ip): 32(ptr) Variable Input 33(ip): 32(ptr) Variable Input
35: TypePointer Output 13(HS_OUT) 35: TypeArray 13(HS_OUT) 10
36(@entryPointOutput): 35(ptr) Variable Output 36: TypePointer Output 35
40: 9(int) Constant 2 37(@entryPointOutput): 36(ptr) Variable Output
41: 9(int) Constant 1 38: TypePointer Input 9(int)
42: 9(int) Constant 0 39(InvocationId): 38(ptr) Variable Input
43: TypePointer Input 9(int) 44: TypePointer Output 13(HS_OUT)
44(InvocationId): 43(ptr) Variable Input 46: 9(int) Constant 2
46: TypeBool 47: 9(int) Constant 1
48: 9(int) Constant 0
50: TypeBool
4(main): 2 Function None 3 4(main): 2 Function None 3
5: Label 5: Label
31(ip): 12(ptr) Variable Function 31(ip): 12(ptr) Variable Function
37(param): 12(ptr) Variable Function 41(param): 12(ptr) Variable Function
34: 11 Load 33(ip) 34: 11 Load 33(ip)
Store 31(ip) 34 Store 31(ip) 34
38: 11 Load 31(ip) 40: 9(int) Load 39(InvocationId)
Store 37(param) 38 42: 11 Load 31(ip)
39: 13(HS_OUT) FunctionCall 16(@main(struct-VS_OUT-vf31[3];) 37(param) Store 41(param) 42
Store 36(@entryPointOutput) 39 43: 13(HS_OUT) FunctionCall 16(@main(struct-VS_OUT-vf31[3];) 41(param)
ControlBarrier 40 41 42 45: 44(ptr) AccessChain 37(@entryPointOutput) 40
45: 9(int) Load 44(InvocationId) Store 45 43
47: 46(bool) IEqual 45 23 ControlBarrier 46 47 48
SelectionMerge 49 None 49: 9(int) Load 39(InvocationId)
BranchConditional 47 48 49 51: 50(bool) IEqual 49 23
48: Label SelectionMerge 53 None
50: 2 FunctionCall 18(PCF() BranchConditional 51 52 53
Branch 49 52: Label
49: Label 54: 2 FunctionCall 18(PCF()
Branch 53
53: Label
Return Return
FunctionEnd FunctionEnd
16(@main(struct-VS_OUT-vf31[3];): 13(HS_OUT) Function None 14 16(@main(struct-VS_OUT-vf31[3];): 13(HS_OUT) Function None 14
......
...@@ -1598,48 +1598,10 @@ TIntermAggregate* HlslParseContext::handleFunctionDefinition(const TSourceLoc& l ...@@ -1598,48 +1598,10 @@ TIntermAggregate* HlslParseContext::handleFunctionDefinition(const TSourceLoc& l
return paramNodes; return paramNodes;
} }
//
// Do all special handling for the entry point, including wrapping
// the shader's entry point with the official entry point that will call it.
//
// The following:
//
// retType shaderEntryPoint(args...) // shader declared entry point
// { body }
//
// Becomes
//
// out retType ret;
// in iargs<that are input>...;
// out oargs<that are output> ...;
//
// void shaderEntryPoint() // synthesized, but official, entry point
// {
// args<that are input> = iargs...;
// ret = @shaderEntryPoint(args...);
// oargs = args<that are output>...;
// }
//
// The symbol table will still map the original entry point name to the
// the modified function and it's new name:
//
// symbol table: shaderEntryPoint -> @shaderEntryPoint
//
// Returns nullptr if no entry-point tree was built, otherwise, returns
// a subtree that creates the entry point.
//
TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunction& userFunction, const TAttributeMap& attributes)
{
// if we aren't in the entry point, fix the IO as such and exit
if (userFunction.getName().compare(intermediate.getEntryPointName().c_str()) != 0) {
remapNonEntryPointIO(userFunction);
return nullptr;
}
entryPointFunction = &userFunction; // needed in finish()
// entry point logic...
// Handle all [attrib] attribute for the shader entry point
void HlslParseContext::handleEntryPointAttributes(const TSourceLoc& loc, TFunction& userFunction, const TAttributeMap& attributes)
{
// Handle entry-point function attributes // Handle entry-point function attributes
const TIntermAggregate* numThreads = attributes[EatNumThreads]; const TIntermAggregate* numThreads = attributes[EatNumThreads];
if (numThreads != nullptr) { if (numThreads != nullptr) {
...@@ -1691,7 +1653,11 @@ TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunct ...@@ -1691,7 +1653,11 @@ TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunct
error(loc, "unsupported domain type", domainStr.c_str(), ""); error(loc, "unsupported domain type", domainStr.c_str(), "");
} }
if (! intermediate.setInputPrimitive(domain)) { if (language == EShLangTessEvaluation) {
if (! intermediate.setInputPrimitive(domain))
error(loc, "cannot change previously set domain", TQualifier::getGeometryString(domain), "");
} else {
if (! intermediate.setOutputPrimitive(domain))
error(loc, "cannot change previously set domain", TQualifier::getGeometryString(domain), ""); error(loc, "cannot change previously set domain", TQualifier::getGeometryString(domain), "");
} }
} }
...@@ -1770,6 +1736,52 @@ TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunct ...@@ -1770,6 +1736,52 @@ TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunct
} }
} }
} }
}
//
// Do all special handling for the entry point, including wrapping
// the shader's entry point with the official entry point that will call it.
//
// The following:
//
// retType shaderEntryPoint(args...) // shader declared entry point
// { body }
//
// Becomes
//
// out retType ret;
// in iargs<that are input>...;
// out oargs<that are output> ...;
//
// void shaderEntryPoint() // synthesized, but official, entry point
// {
// args<that are input> = iargs...;
// ret = @shaderEntryPoint(args...);
// oargs = args<that are output>...;
// }
//
// The symbol table will still map the original entry point name to the
// the modified function and it's new name:
//
// symbol table: shaderEntryPoint -> @shaderEntryPoint
//
// Returns nullptr if no entry-point tree was built, otherwise, returns
// a subtree that creates the entry point.
//
TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunction& userFunction, const TAttributeMap& attributes)
{
// if we aren't in the entry point, fix the IO as such and exit
if (userFunction.getName().compare(intermediate.getEntryPointName().c_str()) != 0) {
remapNonEntryPointIO(userFunction);
return nullptr;
}
entryPointFunction = &userFunction; // needed in finish()
// Handle entry point attributes
handleEntryPointAttributes(loc, userFunction, attributes);
// entry point logic...
// Move parameters and return value to shader in/out // Move parameters and return value to shader in/out
TVariable* entryPointOutput; // gets created in remapEntryPointIO TVariable* entryPointOutput; // gets created in remapEntryPointIO
...@@ -1838,10 +1850,37 @@ TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunct ...@@ -1838,10 +1850,37 @@ TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunct
currentCaller = userFunction.getMangledName(); currentCaller = userFunction.getMangledName();
// Return value // Return value
if (entryPointOutput) if (entryPointOutput) {
intermediate.growAggregate(synthBody, handleAssign(loc, EOpAssign, TIntermTyped* returnAssign;
intermediate.addSymbol(*entryPointOutput), callReturn));
else if (language == EShLangTessControl) {
TIntermSymbol* invocationIdSym = findLinkageSymbol(EbvInvocationId);
// If there is no user declared invocation ID, we must make one.
if (invocationIdSym == nullptr) {
TType invocationIdType(EbtUint, EvqIn, 1);
TString* invocationIdName = NewPoolTString("InvocationId");
invocationIdType.getQualifier().builtIn = EbvInvocationId;
TVariable* variable = makeInternalVariable(*invocationIdName, invocationIdType);
globalQualifierFix(loc, variable->getWritableType().getQualifier());
trackLinkage(*variable);
invocationIdSym = intermediate.addSymbol(*variable);
}
TIntermTyped* element = intermediate.addIndex(EOpIndexIndirect, intermediate.addSymbol(*entryPointOutput),
invocationIdSym, loc);
element->setType(callReturn->getType());
returnAssign = handleAssign(loc, EOpAssign, element, callReturn);
} else {
returnAssign = handleAssign(loc, EOpAssign, intermediate.addSymbol(*entryPointOutput), callReturn);
}
intermediate.growAggregate(synthBody, returnAssign);
} else
intermediate.growAggregate(synthBody, callReturn); intermediate.growAggregate(synthBody, callReturn);
// Output copies // Output copies
...@@ -1914,11 +1953,30 @@ void HlslParseContext::remapEntryPointIO(TFunction& function, TVariable*& return ...@@ -1914,11 +1953,30 @@ void HlslParseContext::remapEntryPointIO(TFunction& function, TVariable*& return
}; };
// return value is actually a shader-scoped output (out) // return value is actually a shader-scoped output (out)
if (function.getType().getBasicType() == EbtVoid) if (function.getType().getBasicType() == EbtVoid) {
returnValue = nullptr; returnValue = nullptr;
else } else {
if (language == EShLangTessControl) {
// tessellation evaluation in HLSL writes a per-ctrl-pt value, but it needs to be an
// array in SPIR-V semantics. We'll write to it indexed by invocation ID.
returnValue = makeIoVariable("@entryPointOutput", function.getWritableType(), EvqVaryingOut); returnValue = makeIoVariable("@entryPointOutput", function.getWritableType(), EvqVaryingOut);
TType outputType;
outputType.shallowCopy(function.getType());
// vertices has necessarily already been set when handling entry point attributes.
TArraySizes arraySizes;
arraySizes.addInnerSize(intermediate.getVertices());
outputType.newArraySizes(arraySizes);
clearUniformInputOutput(function.getWritableType().getQualifier());
returnValue = makeIoVariable("@entryPointOutput", outputType, EvqVaryingOut);
} else {
returnValue = makeIoVariable("@entryPointOutput", function.getWritableType(), EvqVaryingOut);
}
}
// parameters are actually shader-scoped inputs and outputs (in or out) // parameters are actually shader-scoped inputs and outputs (in or out)
for (int i = 0; i < function.getParamCount(); i++) { for (int i = 0; i < function.getParamCount(); i++) {
TType& paramType = *function[i].type; TType& paramType = *function[i].type;
...@@ -7410,6 +7468,17 @@ void HlslParseContext::clearUniformInputOutput(TQualifier& qualifier) ...@@ -7410,6 +7468,17 @@ void HlslParseContext::clearUniformInputOutput(TQualifier& qualifier)
correctUniform(qualifier); correctUniform(qualifier);
} }
// Return a symbol for the linkage variable of the given TBuiltInVariable type
TIntermSymbol* HlslParseContext::findLinkageSymbol(TBuiltInVariable biType) const
{
const auto it = builtInLinkageSymbols.find(biType);
if (it == builtInLinkageSymbols.end()) // if it wasn't declared by the user, return nullptr
return nullptr;
return intermediate.addSymbol(*it->second->getAsVariable());
}
// Add patch constant function invocation // Add patch constant function invocation
void HlslParseContext::addPatchConstantInvocation() void HlslParseContext::addPatchConstantInvocation()
{ {
...@@ -7481,15 +7550,6 @@ void HlslParseContext::addPatchConstantInvocation() ...@@ -7481,15 +7550,6 @@ void HlslParseContext::addPatchConstantInvocation()
} }
}; };
// Return a symbol for the linkage variable of the given TBuiltInVariable type
const auto findLinkageSymbol = [this](TBuiltInVariable biType) -> TIntermSymbol* {
const auto it = builtInLinkageSymbols.find(biType);
if (it == builtInLinkageSymbols.end()) // if it wasn't declared by the user, return nullptr
return nullptr;
return intermediate.addSymbol(*it->second->getAsVariable());
};
const auto isPerCtrlPt = [this](const TType& type) { const auto isPerCtrlPt = [this](const TType& type) {
// TODO: this is not sufficient to reject all such cases in malformed shaders. // TODO: this is not sufficient to reject all such cases in malformed shaders.
return type.isArray() && !type.isRuntimeSizedArray(); return type.isArray() && !type.isRuntimeSizedArray();
......
...@@ -80,6 +80,7 @@ public: ...@@ -80,6 +80,7 @@ public:
void handleFunctionDeclarator(const TSourceLoc&, TFunction& function, bool prototype); void handleFunctionDeclarator(const TSourceLoc&, TFunction& function, bool prototype);
TIntermAggregate* handleFunctionDefinition(const TSourceLoc&, TFunction&, const TAttributeMap&, TIntermNode*& entryPointTree); TIntermAggregate* handleFunctionDefinition(const TSourceLoc&, TFunction&, const TAttributeMap&, TIntermNode*& entryPointTree);
TIntermNode* transformEntryPoint(const TSourceLoc&, TFunction&, const TAttributeMap&); TIntermNode* transformEntryPoint(const TSourceLoc&, TFunction&, const TAttributeMap&);
void handleEntryPointAttributes(const TSourceLoc&, TFunction&, const TAttributeMap&);
void handleFunctionBody(const TSourceLoc&, TFunction&, TIntermNode* functionBody, TIntermNode*& node); void handleFunctionBody(const TSourceLoc&, TFunction&, TIntermNode* functionBody, TIntermNode*& node);
void remapEntryPointIO(TFunction& function, TVariable*& returnValue, TVector<TVariable*>& inputs, TVector<TVariable*>& outputs); void remapEntryPointIO(TFunction& function, TVariable*& returnValue, TVector<TVariable*>& inputs, TVector<TVariable*>& outputs);
void remapNonEntryPointIO(TFunction& function); void remapNonEntryPointIO(TFunction& function);
...@@ -283,6 +284,9 @@ protected: ...@@ -283,6 +284,9 @@ protected:
void finish() override; // post-processing void finish() override; // post-processing
// Linkage symbol helpers
TIntermSymbol* findLinkageSymbol(TBuiltInVariable biType) const;
// Current state of parsing // Current state of parsing
struct TPragma contextPragma; struct TPragma contextPragma;
int loopNestingLevel; // 0 if outside all loops int loopNestingLevel; // 0 if outside all loops
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment