Commit 41436ad2 by John Kessenich

Link/SPV: Correct symbol IDs on merging ASTs to a single coherent space

This is one step in providing full linker functionality for creating correct SPIR-V from multiple compilation units for the same stage. (This was the only remaining "hard" part. The rest should be simple.)
parent e7f9caea
...@@ -198,7 +198,7 @@ gl_FragCoord origin is upper left ...@@ -198,7 +198,7 @@ gl_FragCoord origin is upper left
// Module Version 10000 // Module Version 10000
// Generated by (magic number): 80007 // Generated by (magic number): 80007
// Id's are bound by 71 // Id's are bound by 70
Capability Shader Capability Shader
1: ExtInstImport "GLSL.std.450" 1: ExtInstImport "GLSL.std.450"
...@@ -214,24 +214,24 @@ gl_FragCoord origin is upper left ...@@ -214,24 +214,24 @@ gl_FragCoord origin is upper left
Name 32 "b" Name 32 "b"
Name 33 "i" Name 33 "i"
Name 39 "c" Name 39 "c"
Name 54 "s2D" Name 53 "s2D"
Name 63 "bnameRuntime" Name 62 "bnameRuntime"
MemberName 63(bnameRuntime) 0 "r" MemberName 62(bnameRuntime) 0 "r"
Name 65 "" Name 64 ""
Name 68 "bnameImplicit" Name 67 "bnameImplicit"
MemberName 68(bnameImplicit) 0 "m" MemberName 67(bnameImplicit) 0 "m"
Name 70 "" Name 69 ""
Decorate 12(color) Location 0 Decorate 12(color) Location 0
Decorate 54(s2D) DescriptorSet 0 Decorate 53(s2D) DescriptorSet 0
Decorate 54(s2D) Binding 1 Decorate 53(s2D) Binding 1
Decorate 62 ArrayStride 4 Decorate 61 ArrayStride 4
MemberDecorate 63(bnameRuntime) 0 Offset 0 MemberDecorate 62(bnameRuntime) 0 Offset 0
Decorate 63(bnameRuntime) BufferBlock Decorate 62(bnameRuntime) BufferBlock
Decorate 65 DescriptorSet 0 Decorate 64 DescriptorSet 0
Decorate 67 ArrayStride 4 Decorate 66 ArrayStride 4
MemberDecorate 68(bnameImplicit) 0 Offset 0 MemberDecorate 67(bnameImplicit) 0 Offset 0
Decorate 68(bnameImplicit) BufferBlock Decorate 67(bnameImplicit) BufferBlock
Decorate 70 DescriptorSet 0 Decorate 69 DescriptorSet 0
2: TypeVoid 2: TypeVoid
3: TypeFunction 2 3: TypeFunction 2
6: TypeFloat 32 6: TypeFloat 32
...@@ -263,24 +263,23 @@ gl_FragCoord origin is upper left ...@@ -263,24 +263,23 @@ gl_FragCoord origin is upper left
39(c): 38(ptr) Variable Private 39(c): 38(ptr) Variable Private
40: 14(int) Constant 3 40: 14(int) Constant 3
42: 14(int) Constant 2 42: 14(int) Constant 2
43: TypePointer Output 6(float) 44: 14(int) Constant 9
45: 14(int) Constant 9 50: TypeImage 6(float) 2D sampled format:Unknown
51: TypeImage 6(float) 2D sampled format:Unknown 51: TypeSampledImage 50
52: TypeSampledImage 51 52: TypePointer UniformConstant 51
53: TypePointer UniformConstant 52 53(s2D): 52(ptr) Variable UniformConstant
54(s2D): 53(ptr) Variable UniformConstant 55: TypeVector 6(float) 2
56: TypeVector 6(float) 2 56: 6(float) Constant 1056964608
57: 6(float) Constant 1056964608 57: 55(fvec2) ConstantComposite 56 56
58: 56(fvec2) ConstantComposite 57 57 61: TypeRuntimeArray 6(float)
62: TypeRuntimeArray 6(float) 62(bnameRuntime): TypeStruct 61
63(bnameRuntime): TypeStruct 62 63: TypePointer Uniform 62(bnameRuntime)
64: TypePointer Uniform 63(bnameRuntime) 64: 63(ptr) Variable Uniform
65: 64(ptr) Variable Uniform 65: 15(int) Constant 4
66: 15(int) Constant 4 66: TypeArray 6(float) 65
67: TypeArray 6(float) 66 67(bnameImplicit): TypeStruct 66
68(bnameImplicit): TypeStruct 67 68: TypePointer Uniform 67(bnameImplicit)
69: TypePointer Uniform 68(bnameImplicit) 69: 68(ptr) Variable Uniform
70: 69(ptr) Variable Uniform
4(main): 2 Function None 3 4(main): 2 Function None 3
5: Label 5: Label
13: 7(fvec4) FunctionCall 9(getColor() 13: 7(fvec4) FunctionCall 9(getColor()
...@@ -298,18 +297,18 @@ gl_FragCoord origin is upper left ...@@ -298,18 +297,18 @@ gl_FragCoord origin is upper left
FunctionEnd FunctionEnd
9(getColor(): 7(fvec4) Function None 8 9(getColor(): 7(fvec4) Function None 8
10: Label 10: Label
44: 43(ptr) AccessChain 12(color) 42 43: 22(ptr) AccessChain 19(a1) 42
Store 44 21 Store 43 21
46: 22(ptr) AccessChain 19(a1) 45 45: 22(ptr) AccessChain 27(a2) 44
Store 45 21
46: 22(ptr) AccessChain 32(b) 42
Store 46 21 Store 46 21
47: 22(ptr) AccessChain 27(a2) 42 47: 22(ptr) AccessChain 39(c) 40
Store 47 21 Store 47 21
48: 22(ptr) AccessChain 32(b) 40 48: 14(int) Load 33(i)
Store 48 21 49: 22(ptr) AccessChain 39(c) 48
49: 37 Load 39(c) Store 49 21
50: 22(ptr) AccessChain 32(b) 49 54: 51 Load 53(s2D)
Store 50 21 58: 7(fvec4) ImageSampleImplicitLod 54 57
55: 52 Load 54(s2D) ReturnValue 58
59: 7(fvec4) ImageSampleImplicitLod 55 58
ReturnValue 59
FunctionEnd FunctionEnd
#version 460
float f;
float a1;
float foo();
void main()
{
f = 10;
float g = foo();
f += g;
f += gl_FragCoord.y;
}
\ No newline at end of file
#version 410
// a different version number makes different id's for the same shared symbol
float a2;
float f;
float bar();
float foo()
{
float h2 = 2 * f;
float g2 = bar();
return h2 + g2 + gl_FragCoord.y;
}
\ No newline at end of file
#version 460
float f;
float h3 = 3.0;
float bar()
{
h3 *= f;
float g3 = 2 * h3;
return h3 + g3 + gl_FragCoord.y;
}
...@@ -1164,6 +1164,7 @@ public: ...@@ -1164,6 +1164,7 @@ public:
constSubtree(nullptr) constSubtree(nullptr)
{ name = n; } { name = n; }
virtual int getId() const { return id; } virtual int getId() const { return id; }
virtual void changeId(int i) { id = i; }
virtual const TString& getName() const { return name; } virtual const TString& getName() const { return name; }
virtual void traverse(TIntermTraverser*); virtual void traverse(TIntermTraverser*);
virtual TIntermSymbol* getAsSymbolNode() { return this; } virtual TIntermSymbol* getAsSymbolNode() { return this; }
......
...@@ -182,22 +182,132 @@ void TIntermediate::merge(TInfoSink& infoSink, TIntermediate& unit) ...@@ -182,22 +182,132 @@ void TIntermediate::merge(TInfoSink& infoSink, TIntermediate& unit)
} }
// Getting this far means we have two existing trees to merge... // Getting this far means we have two existing trees to merge...
mergeTree(infoSink, unit);
version = std::max(version, unit.version); version = std::max(version, unit.version);
requestedExtensions.insert(unit.requestedExtensions.begin(), unit.requestedExtensions.end()); requestedExtensions.insert(unit.requestedExtensions.begin(), unit.requestedExtensions.end());
ioAccessed.insert(unit.ioAccessed.begin(), unit.ioAccessed.end());
}
//
// Merge the 'unit' AST into 'this' AST.
// That includes rationalizing the unique IDs, which were set up independently,
// and might have overlaps that are not the same symbol, or might have different
// IDs for what should be the same shared symbol.
//
void TIntermediate::mergeTree(TInfoSink& infoSink, TIntermediate& unit)
{
// Get the top-level globals of each unit // Get the top-level globals of each unit
TIntermSequence& globals = treeRoot->getAsAggregate()->getSequence(); TIntermSequence& globals = treeRoot->getAsAggregate()->getSequence();
TIntermSequence& unitGlobals = unit.treeRoot->getAsAggregate()->getSequence(); TIntermSequence& unitGlobals = unit.treeRoot->getAsAggregate()->getSequence();
// Get the linker-object lists // Get the linker-object lists
TIntermSequence& linkerObjects = findLinkerObjects(); TIntermSequence& linkerObjects = findLinkerObjects()->getSequence();
TIntermSequence& unitLinkerObjects = unit.findLinkerObjects(); const TIntermSequence& unitLinkerObjects = unit.findLinkerObjects()->getSequence();
// Map by global name to unique ID to rationalize the same object having
// differing IDs in different trees.
TMap<TString, int> idMap;
int maxId;
seedIdMap(idMap, maxId);
remapIds(idMap, maxId + 1, unit);
mergeBodies(infoSink, globals, unitGlobals); mergeBodies(infoSink, globals, unitGlobals);
mergeLinkerObjects(infoSink, linkerObjects, unitLinkerObjects); mergeLinkerObjects(infoSink, linkerObjects, unitLinkerObjects);
}
ioAccessed.insert(unit.ioAccessed.begin(), unit.ioAccessed.end()); // Traverser that seeds an ID map with all built-ins, and tracks the
// maximum ID used.
// (It would be nice to put this in a function, but that causes warnings
// on having no bodies for the copy-constructor/operator=.)
class TBuiltInIdTraverser : public TIntermTraverser {
public:
TBuiltInIdTraverser(TMap<TString, int>& idMap) : idMap(idMap), maxId(0) { }
// If it's a built in, add it to the map.
// Track the max ID.
virtual void visitSymbol(TIntermSymbol* symbol)
{
const TQualifier& qualifier = symbol->getType().getQualifier();
if (qualifier.builtIn != EbvNone)
idMap[symbol->getName()] = symbol->getId();
maxId = std::max(maxId, symbol->getId());
}
int getMaxId() const { return maxId; }
protected:
TBuiltInIdTraverser(TBuiltInIdTraverser&);
TBuiltInIdTraverser& operator=(TBuiltInIdTraverser&);
TMap<TString, int>& idMap;
int maxId;
};
// Traverser that seeds an ID map with non-builtin globals.
// (It would be nice to put this in a function, but that causes warnings
// on having no bodies for the copy-constructor/operator=.)
class TUserIdTraverser : public TIntermTraverser {
public:
TUserIdTraverser(TMap<TString, int>& idMap) : idMap(idMap) { }
// If its a non-built-in global, add it to the map.
virtual void visitSymbol(TIntermSymbol* symbol)
{
const TQualifier& qualifier = symbol->getType().getQualifier();
if (qualifier.storage == EvqGlobal && qualifier.builtIn == EbvNone)
idMap[symbol->getName()] = symbol->getId();
}
protected:
TUserIdTraverser(TUserIdTraverser&);
TUserIdTraverser& operator=(TUserIdTraverser&);
TMap<TString, int>& idMap; // over biggest id
};
// Initialize the the ID map with what we know of 'this' AST.
void TIntermediate::seedIdMap(TMap<TString, int>& idMap, int& maxId)
{
// all built-ins everywhere need to align on IDs and contribute to the max ID
TBuiltInIdTraverser builtInIdTraverser(idMap);
treeRoot->traverse(&builtInIdTraverser);
maxId = builtInIdTraverser.getMaxId();
// user variables in the linker object list need to align on ids
TUserIdTraverser userIdTraverser(idMap);
findLinkerObjects()->traverse(&userIdTraverser);
}
// Traverser to map an AST ID to what was known from the seeding AST.
// (It would be nice to put this in a function, but that causes warnings
// on having no bodies for the copy-constructor/operator=.)
class TRemapIdTraverser : public TIntermTraverser {
public:
TRemapIdTraverser(const TMap<TString, int>& idMap, int idShift) : idMap(idMap), idShift(idShift) { }
// Do the mapping:
// - if the same symbol, adopt the 'this' ID
// - otherwise, ensure a unique ID by shifting to a new space
virtual void visitSymbol(TIntermSymbol* symbol)
{
const TQualifier& qualifier = symbol->getType().getQualifier();
bool remapped = false;
if (qualifier.storage == EvqGlobal || qualifier.builtIn != EbvNone) {
auto it = idMap.find(symbol->getName());
if (it != idMap.end()) {
symbol->changeId(it->second);
remapped = true;
}
}
if (!remapped)
symbol->changeId(symbol->getId() + idShift);
}
protected:
TRemapIdTraverser(TRemapIdTraverser&);
TRemapIdTraverser& operator=(TRemapIdTraverser&);
const TMap<TString, int>& idMap;
int idShift;
};
void TIntermediate::remapIds(const TMap<TString, int>& idMap, int idShift, TIntermediate& unit)
{
// Remap all IDs to either share or be unique, as dictated by the idMap and idShift.
TRemapIdTraverser idTraverser(idMap, idShift);
unit.getTreeRoot()->traverse(&idTraverser);
} }
// //
...@@ -699,7 +809,7 @@ void TIntermediate::inOutLocationCheck(TInfoSink& infoSink) ...@@ -699,7 +809,7 @@ void TIntermediate::inOutLocationCheck(TInfoSink& infoSink)
// TODO: linker functionality: location collision checking // TODO: linker functionality: location collision checking
TIntermSequence& linkObjects = findLinkerObjects(); TIntermSequence& linkObjects = findLinkerObjects()->getSequence();
for (size_t i = 0; i < linkObjects.size(); ++i) { for (size_t i = 0; i < linkObjects.size(); ++i) {
const TType& type = linkObjects[i]->getAsTyped()->getType(); const TType& type = linkObjects[i]->getAsTyped()->getType();
const TQualifier& qualifier = type.getQualifier(); const TQualifier& qualifier = type.getQualifier();
...@@ -718,7 +828,7 @@ void TIntermediate::inOutLocationCheck(TInfoSink& infoSink) ...@@ -718,7 +828,7 @@ void TIntermediate::inOutLocationCheck(TInfoSink& infoSink)
} }
} }
TIntermSequence& TIntermediate::findLinkerObjects() const TIntermAggregate* TIntermediate::findLinkerObjects() const
{ {
// Get the top-level globals // Get the top-level globals
TIntermSequence& globals = treeRoot->getAsAggregate()->getSequence(); TIntermSequence& globals = treeRoot->getAsAggregate()->getSequence();
...@@ -726,7 +836,7 @@ TIntermSequence& TIntermediate::findLinkerObjects() const ...@@ -726,7 +836,7 @@ TIntermSequence& TIntermediate::findLinkerObjects() const
// Get the last member of the sequences, expected to be the linker-object lists // Get the last member of the sequences, expected to be the linker-object lists
assert(globals.back()->getAsAggregate()->getOp() == EOpLinkerObjects); assert(globals.back()->getAsAggregate()->getOp() == EOpLinkerObjects);
return globals.back()->getAsAggregate()->getSequence(); return globals.back()->getAsAggregate();
} }
// See if a variable was both a user-declared output and used. // See if a variable was both a user-declared output and used.
...@@ -734,7 +844,7 @@ TIntermSequence& TIntermediate::findLinkerObjects() const ...@@ -734,7 +844,7 @@ TIntermSequence& TIntermediate::findLinkerObjects() const
// is more useful, and perhaps the spec should be changed to reflect that. // is more useful, and perhaps the spec should be changed to reflect that.
bool TIntermediate::userOutputUsed() const bool TIntermediate::userOutputUsed() const
{ {
const TIntermSequence& linkerObjects = findLinkerObjects(); const TIntermSequence& linkerObjects = findLinkerObjects()->getSequence();
bool found = false; bool found = false;
for (size_t i = 0; i < linkerObjects.size(); ++i) { for (size_t i = 0; i < linkerObjects.size(); ++i) {
......
...@@ -645,6 +645,9 @@ protected: ...@@ -645,6 +645,9 @@ protected:
TIntermSymbol* addSymbol(int Id, const TString&, const TType&, const TConstUnionArray&, TIntermTyped* subtree, const TSourceLoc&); TIntermSymbol* addSymbol(int Id, const TString&, const TType&, const TConstUnionArray&, TIntermTyped* subtree, const TSourceLoc&);
void error(TInfoSink& infoSink, const char*); void error(TInfoSink& infoSink, const char*);
void warn(TInfoSink& infoSink, const char*); void warn(TInfoSink& infoSink, const char*);
void mergeTree(TInfoSink&, TIntermediate&);
void seedIdMap(TMap<TString, int>& idMap, int& maxId);
void remapIds(const TMap<TString, int>& idMap, int idShift, TIntermediate&);
void mergeBodies(TInfoSink&, TIntermSequence& globals, const TIntermSequence& unitGlobals); void mergeBodies(TInfoSink&, TIntermSequence& globals, const TIntermSequence& unitGlobals);
void mergeLinkerObjects(TInfoSink&, TIntermSequence& linkerObjects, const TIntermSequence& unitLinkerObjects); void mergeLinkerObjects(TInfoSink&, TIntermSequence& linkerObjects, const TIntermSequence& unitLinkerObjects);
void mergeImplicitArraySizes(TType&, const TType&); void mergeImplicitArraySizes(TType&, const TType&);
...@@ -652,7 +655,7 @@ protected: ...@@ -652,7 +655,7 @@ protected:
void checkCallGraphCycles(TInfoSink&); void checkCallGraphCycles(TInfoSink&);
void checkCallGraphBodies(TInfoSink&, bool keepUncalled); void checkCallGraphBodies(TInfoSink&, bool keepUncalled);
void inOutLocationCheck(TInfoSink&); void inOutLocationCheck(TInfoSink&);
TIntermSequence& findLinkerObjects() const; TIntermAggregate* findLinkerObjects() const;
bool userOutputUsed() const; bool userOutputUsed() const;
bool isSpecializationOperation(const TIntermOperator&) const; bool isSpecializationOperation(const TIntermOperator&) const;
bool isNonuniformPropagating(TOperator) const; bool isNonuniformPropagating(TOperator) const;
......
...@@ -106,6 +106,7 @@ INSTANTIATE_TEST_CASE_P( ...@@ -106,6 +106,7 @@ INSTANTIATE_TEST_CASE_P(
Glsl, LinkTestVulkan, Glsl, LinkTestVulkan,
::testing::ValuesIn(std::vector<std::vector<std::string>>({ ::testing::ValuesIn(std::vector<std::vector<std::string>>({
{"link1.vk.frag", "link2.vk.frag"}, {"link1.vk.frag", "link2.vk.frag"},
{"spv.unit1.frag", "spv.unit2.frag", "spv.unit3.frag"},
})), })),
); );
// clang-format on // clang-format on
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment