Commit e752f463 by steve-lunarg

HLSL: HS return is arrayed to match SPIR-V semantics

The HLSL hull shader (HS) outputs a per-control-point value, and the domain shader (DS) reads an array of that type. (There is also a per-patch frequency.) In HLSL the per-control-point frequency is arrayed on just one side, whereas in SPIR-V it is arrayed on both. To match SPIR-V semantics, the compiler creates an array behind the scenes and indexes it by invocation ID, assigning the HS return value to that element.
parent 7afe1344
......@@ -27,7 +27,9 @@ vertices = 3
0:? 'ip' ( temp 3-element array of structure{ temp 3-component vector of float cpoint})
0:? 'ip' (layout( location=0) in 3-element array of structure{ temp 3-component vector of float cpoint})
0:26 move second child to first child ( temp structure{ temp 3-component vector of float cpoint})
0:? '@entryPointOutput' (layout( location=0) out structure{ temp 3-component vector of float cpoint})
0:26 indirect index ( temp structure{ temp 3-component vector of float cpoint})
0:? '@entryPointOutput' (layout( location=0) out 3-element array of structure{ temp 3-component vector of float cpoint})
0:? 'InvocationId' ( in uint InvocationID)
0:26 Function Call: @main(struct-VS_OUT-vf31[3]; ( temp structure{ temp 3-component vector of float cpoint})
0:? 'ip' ( temp 3-element array of structure{ temp 3-component vector of float cpoint})
0:? Barrier ( temp void)
......@@ -43,7 +45,7 @@ vertices = 3
0:33 Function Definition: PCF( ( temp void)
0:33 Function Parameters:
0:? Linker Objects
0:? '@entryPointOutput' (layout( location=0) out structure{ temp 3-component vector of float cpoint})
0:? '@entryPointOutput' (layout( location=0) out 3-element array of structure{ temp 3-component vector of float cpoint})
0:? 'ip' (layout( location=0) in 3-element array of structure{ temp 3-component vector of float cpoint})
0:? 'InvocationId' ( in uint InvocationID)
......@@ -79,7 +81,9 @@ vertices = 3
0:? 'ip' ( temp 3-element array of structure{ temp 3-component vector of float cpoint})
0:? 'ip' (layout( location=0) in 3-element array of structure{ temp 3-component vector of float cpoint})
0:26 move second child to first child ( temp structure{ temp 3-component vector of float cpoint})
0:? '@entryPointOutput' (layout( location=0) out structure{ temp 3-component vector of float cpoint})
0:26 indirect index ( temp structure{ temp 3-component vector of float cpoint})
0:? '@entryPointOutput' (layout( location=0) out 3-element array of structure{ temp 3-component vector of float cpoint})
0:? 'InvocationId' ( in uint InvocationID)
0:26 Function Call: @main(struct-VS_OUT-vf31[3]; ( temp structure{ temp 3-component vector of float cpoint})
0:? 'ip' ( temp 3-element array of structure{ temp 3-component vector of float cpoint})
0:? Barrier ( temp void)
......@@ -95,18 +99,18 @@ vertices = 3
0:33 Function Definition: PCF( ( temp void)
0:33 Function Parameters:
0:? Linker Objects
0:? '@entryPointOutput' (layout( location=0) out structure{ temp 3-component vector of float cpoint})
0:? '@entryPointOutput' (layout( location=0) out 3-element array of structure{ temp 3-component vector of float cpoint})
0:? 'ip' (layout( location=0) in 3-element array of structure{ temp 3-component vector of float cpoint})
0:? 'InvocationId' ( in uint InvocationID)
// Module Version 10000
// Generated by (magic number): 80001
// Id's are bound by 51
// Id's are bound by 55
Capability Tessellation
1: ExtInstImport "GLSL.std.450"
MemoryModel Logical GLSL450
EntryPoint TessellationControl 4 "main" 33 36 44
EntryPoint TessellationControl 4 "main" 33 37 39
ExecutionMode 4 OutputVertices 3
Name 4 "main"
Name 8 "VS_OUT"
......@@ -119,12 +123,12 @@ vertices = 3
Name 21 "output"
Name 31 "ip"
Name 33 "ip"
Name 36 "@entryPointOutput"
Name 37 "param"
Name 44 "InvocationId"
Name 37 "@entryPointOutput"
Name 39 "InvocationId"
Name 41 "param"
Decorate 33(ip) Location 0
Decorate 36(@entryPointOutput) Location 0
Decorate 44(InvocationId) BuiltIn InvocationId
Decorate 37(@entryPointOutput) Location 0
Decorate 39(InvocationId) BuiltIn InvocationId
2: TypeVoid
3: TypeFunction 2
6: TypeFloat 32
......@@ -142,33 +146,37 @@ vertices = 3
24: TypePointer Function 7(fvec3)
32: TypePointer Input 11
33(ip): 32(ptr) Variable Input
35: TypePointer Output 13(HS_OUT)
36(@entryPointOutput): 35(ptr) Variable Output
40: 9(int) Constant 2
41: 9(int) Constant 1
42: 9(int) Constant 0
43: TypePointer Input 9(int)
44(InvocationId): 43(ptr) Variable Input
46: TypeBool
35: TypeArray 13(HS_OUT) 10
36: TypePointer Output 35
37(@entryPointOutput): 36(ptr) Variable Output
38: TypePointer Input 9(int)
39(InvocationId): 38(ptr) Variable Input
44: TypePointer Output 13(HS_OUT)
46: 9(int) Constant 2
47: 9(int) Constant 1
48: 9(int) Constant 0
50: TypeBool
4(main): 2 Function None 3
5: Label
31(ip): 12(ptr) Variable Function
37(param): 12(ptr) Variable Function
41(param): 12(ptr) Variable Function
34: 11 Load 33(ip)
Store 31(ip) 34
38: 11 Load 31(ip)
Store 37(param) 38
39: 13(HS_OUT) FunctionCall 16(@main(struct-VS_OUT-vf31[3];) 37(param)
Store 36(@entryPointOutput) 39
ControlBarrier 40 41 42
45: 9(int) Load 44(InvocationId)
47: 46(bool) IEqual 45 23
SelectionMerge 49 None
BranchConditional 47 48 49
48: Label
50: 2 FunctionCall 18(PCF()
Branch 49
49: Label
40: 9(int) Load 39(InvocationId)
42: 11 Load 31(ip)
Store 41(param) 42
43: 13(HS_OUT) FunctionCall 16(@main(struct-VS_OUT-vf31[3];) 41(param)
45: 44(ptr) AccessChain 37(@entryPointOutput) 40
Store 45 43
ControlBarrier 46 47 48
49: 9(int) Load 39(InvocationId)
51: 50(bool) IEqual 49 23
SelectionMerge 53 None
BranchConditional 51 52 53
52: Label
54: 2 FunctionCall 18(PCF()
Branch 53
53: Label
Return
FunctionEnd
16(@main(struct-VS_OUT-vf31[3];): 13(HS_OUT) Function None 14
......
......@@ -1598,48 +1598,10 @@ TIntermAggregate* HlslParseContext::handleFunctionDefinition(const TSourceLoc& l
return paramNodes;
}
//
// Do all special handling for the entry point, including wrapping
// the shader's entry point with the official entry point that will call it.
//
// The following:
//
// retType shaderEntryPoint(args...) // shader declared entry point
// { body }
//
// Becomes
//
// out retType ret;
// in iargs<that are input>...;
// out oargs<that are output> ...;
//
// void shaderEntryPoint() // synthesized, but official, entry point
// {
// args<that are input> = iargs...;
// ret = @shaderEntryPoint(args...);
// oargs = args<that are output>...;
// }
//
// The symbol table will still map the original entry point name to
// the modified function and its new name:
//
// symbol table: shaderEntryPoint -> @shaderEntryPoint
//
// Returns nullptr if no entry-point tree was built, otherwise, returns
// a subtree that creates the entry point.
//
TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunction& userFunction, const TAttributeMap& attributes)
{
// if we aren't in the entry point, fix the IO as such and exit
if (userFunction.getName().compare(intermediate.getEntryPointName().c_str()) != 0) {
remapNonEntryPointIO(userFunction);
return nullptr;
}
entryPointFunction = &userFunction; // needed in finish()
// entry point logic...
// Handle all [attrib] attribute for the shader entry point
void HlslParseContext::handleEntryPointAttributes(const TSourceLoc& loc, TFunction& userFunction, const TAttributeMap& attributes)
{
// Handle entry-point function attributes
const TIntermAggregate* numThreads = attributes[EatNumThreads];
if (numThreads != nullptr) {
......@@ -1691,8 +1653,12 @@ TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunct
error(loc, "unsupported domain type", domainStr.c_str(), "");
}
if (! intermediate.setInputPrimitive(domain)) {
error(loc, "cannot change previously set domain", TQualifier::getGeometryString(domain), "");
if (language == EShLangTessEvaluation) {
if (! intermediate.setInputPrimitive(domain))
error(loc, "cannot change previously set domain", TQualifier::getGeometryString(domain), "");
} else {
if (! intermediate.setOutputPrimitive(domain))
error(loc, "cannot change previously set domain", TQualifier::getGeometryString(domain), "");
}
}
}
......@@ -1770,6 +1736,52 @@ TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunct
}
}
}
}
//
// Do all special handling for the entry point, including wrapping
// the shader's entry point with the official entry point that will call it.
//
// The following:
//
// retType shaderEntryPoint(args...) // shader declared entry point
// { body }
//
// Becomes
//
// out retType ret;
// in iargs<that are input>...;
// out oargs<that are output> ...;
//
// void shaderEntryPoint() // synthesized, but official, entry point
// {
// args<that are input> = iargs...;
// ret = @shaderEntryPoint(args...);
// oargs = args<that are output>...;
// }
//
// The symbol table will still map the original entry point name to
// the modified function and its new name:
//
// symbol table: shaderEntryPoint -> @shaderEntryPoint
//
// Returns nullptr if no entry-point tree was built, otherwise, returns
// a subtree that creates the entry point.
//
TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunction& userFunction, const TAttributeMap& attributes)
{
// if we aren't in the entry point, fix the IO as such and exit
if (userFunction.getName().compare(intermediate.getEntryPointName().c_str()) != 0) {
remapNonEntryPointIO(userFunction);
return nullptr;
}
entryPointFunction = &userFunction; // needed in finish()
// Handle entry point attributes
handleEntryPointAttributes(loc, userFunction, attributes);
// entry point logic...
// Move parameters and return value to shader in/out
TVariable* entryPointOutput; // gets created in remapEntryPointIO
......@@ -1838,10 +1850,37 @@ TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunct
currentCaller = userFunction.getMangledName();
// Return value
if (entryPointOutput)
intermediate.growAggregate(synthBody, handleAssign(loc, EOpAssign,
intermediate.addSymbol(*entryPointOutput), callReturn));
else
if (entryPointOutput) {
TIntermTyped* returnAssign;
if (language == EShLangTessControl) {
TIntermSymbol* invocationIdSym = findLinkageSymbol(EbvInvocationId);
// If there is no user declared invocation ID, we must make one.
if (invocationIdSym == nullptr) {
TType invocationIdType(EbtUint, EvqIn, 1);
TString* invocationIdName = NewPoolTString("InvocationId");
invocationIdType.getQualifier().builtIn = EbvInvocationId;
TVariable* variable = makeInternalVariable(*invocationIdName, invocationIdType);
globalQualifierFix(loc, variable->getWritableType().getQualifier());
trackLinkage(*variable);
invocationIdSym = intermediate.addSymbol(*variable);
}
TIntermTyped* element = intermediate.addIndex(EOpIndexIndirect, intermediate.addSymbol(*entryPointOutput),
invocationIdSym, loc);
element->setType(callReturn->getType());
returnAssign = handleAssign(loc, EOpAssign, element, callReturn);
} else {
returnAssign = handleAssign(loc, EOpAssign, intermediate.addSymbol(*entryPointOutput), callReturn);
}
intermediate.growAggregate(synthBody, returnAssign);
} else
intermediate.growAggregate(synthBody, callReturn);
// Output copies
......@@ -1914,10 +1953,29 @@ void HlslParseContext::remapEntryPointIO(TFunction& function, TVariable*& return
};
// return value is actually a shader-scoped output (out)
if (function.getType().getBasicType() == EbtVoid)
if (function.getType().getBasicType() == EbtVoid) {
returnValue = nullptr;
else
returnValue = makeIoVariable("@entryPointOutput", function.getWritableType(), EvqVaryingOut);
} else {
if (language == EShLangTessControl) {
// tessellation control (the HLSL hull shader) writes a per-ctrl-pt value, but it needs to be an
// array in SPIR-V semantics. We'll write to it indexed by invocation ID.
returnValue = makeIoVariable("@entryPointOutput", function.getWritableType(), EvqVaryingOut);
TType outputType;
outputType.shallowCopy(function.getType());
// vertices has necessarily already been set when handling entry point attributes.
TArraySizes arraySizes;
arraySizes.addInnerSize(intermediate.getVertices());
outputType.newArraySizes(arraySizes);
clearUniformInputOutput(function.getWritableType().getQualifier());
returnValue = makeIoVariable("@entryPointOutput", outputType, EvqVaryingOut);
} else {
returnValue = makeIoVariable("@entryPointOutput", function.getWritableType(), EvqVaryingOut);
}
}
// parameters are actually shader-scoped inputs and outputs (in or out)
for (int i = 0; i < function.getParamCount(); i++) {
......@@ -7410,6 +7468,17 @@ void HlslParseContext::clearUniformInputOutput(TQualifier& qualifier)
correctUniform(qualifier);
}
// Look up the linkage symbol recorded for a built-in variable kind.
//
// Searches builtInLinkageSymbols for biType. When an entry exists, the
// entry's variable is passed to intermediate.addSymbol() and that result
// is returned; when there is no entry, the result is nullptr.
TIntermSymbol* HlslParseContext::findLinkageSymbol(TBuiltInVariable biType) const
{
    const auto found = builtInLinkageSymbols.find(biType);

    // Present in the map: hand its variable to the intermediate representation.
    if (found != builtInLinkageSymbols.end())
        return intermediate.addSymbol(*found->second->getAsVariable());

    // if it wasn't declared by the user, return nullptr
    return nullptr;
}
// Add patch constant function invocation
void HlslParseContext::addPatchConstantInvocation()
{
......@@ -7481,15 +7550,6 @@ void HlslParseContext::addPatchConstantInvocation()
}
};
// Return a symbol for the linkage variable of the given TBuiltInVariable type
const auto findLinkageSymbol = [this](TBuiltInVariable biType) -> TIntermSymbol* {
const auto it = builtInLinkageSymbols.find(biType);
if (it == builtInLinkageSymbols.end()) // if it wasn't declared by the user, return nullptr
return nullptr;
return intermediate.addSymbol(*it->second->getAsVariable());
};
const auto isPerCtrlPt = [this](const TType& type) {
// TODO: this is not sufficient to reject all such cases in malformed shaders.
return type.isArray() && !type.isRuntimeSizedArray();
......
......@@ -80,6 +80,7 @@ public:
void handleFunctionDeclarator(const TSourceLoc&, TFunction& function, bool prototype);
TIntermAggregate* handleFunctionDefinition(const TSourceLoc&, TFunction&, const TAttributeMap&, TIntermNode*& entryPointTree);
TIntermNode* transformEntryPoint(const TSourceLoc&, TFunction&, const TAttributeMap&);
void handleEntryPointAttributes(const TSourceLoc&, TFunction&, const TAttributeMap&);
void handleFunctionBody(const TSourceLoc&, TFunction&, TIntermNode* functionBody, TIntermNode*& node);
void remapEntryPointIO(TFunction& function, TVariable*& returnValue, TVector<TVariable*>& inputs, TVector<TVariable*>& outputs);
void remapNonEntryPointIO(TFunction& function);
......@@ -283,6 +284,9 @@ protected:
void finish() override; // post-processing
// Linkage symbol helpers
TIntermSymbol* findLinkageSymbol(TBuiltInVariable biType) const;
// Current state of parsing
struct TPragma contextPragma;
int loopNestingLevel; // 0 if outside all loops
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment