Commit 112e2858 by Arseny Kapoulkine

SPIRV: Simplify matrix->matrix constructor

When constructing a matrix from another matrix whose dimensions are at least as large (i.e. truncating it, as in mat3(mat4)), there's no need to extract the scalars out of the columns and rebuild the resulting matrix from scalars. Instead, we can extract each source column, shorten it with OpVectorShuffle when the row count differs, and combine the columns into the final result. This keeps common casts such as mat3(mat4) in vector registers, which may improve performance on some GPUs, and cleans up the output of translation tools like SPIRV-Cross. Fixes #1412.
parent cd57b4ba
...@@ -2031,83 +2031,113 @@ Id Builder::createMatrixConstructor(Decoration precision, const std::vector<Id>& ...@@ -2031,83 +2031,113 @@ Id Builder::createMatrixConstructor(Decoration precision, const std::vector<Id>&
Instruction* instr = module.getInstruction(componentTypeId); Instruction* instr = module.getInstruction(componentTypeId);
Id bitCount = instr->getIdOperand(0); Id bitCount = instr->getIdOperand(0);
// Will use a two step process if (isMatrix(sources[0]) && getNumColumns(sources[0]) >= numCols && getNumRows(sources[0]) >= numRows) {
// 1. make a compile-time 2D array of values // To truncate the matrix to a smaller number of rows/columns, we need to:
// 2. construct a matrix from that array // 1. For each column, extract the column and truncate it to the required size using shuffle
// 2. Assemble the resulting matrix from all columns
// Step 1. Id matrix = sources[0];
Id columnTypeId = getContainedTypeId(resultTypeId);
// initialize the array to the identity matrix Id sourceColumnTypeId = getContainedTypeId(getTypeId(matrix));
Id ids[maxMatrixSize][maxMatrixSize];
Id one = (bitCount == 64 ? makeDoubleConstant(1.0) : makeFloatConstant(1.0)); std::vector<unsigned> channels;
Id zero = (bitCount == 64 ? makeDoubleConstant(0.0) : makeFloatConstant(0.0)); for (int row = 0; row < numRows; ++row) {
for (int col = 0; col < 4; ++col) { channels.push_back(row);
for (int row = 0; row < 4; ++row) {
if (col == row)
ids[col][row] = one;
else
ids[col][row] = zero;
} }
}
// modify components as dictated by the arguments std::vector<Id> matrixColumns;
if (sources.size() == 1 && isScalar(sources[0])) { for (int col = 0; col < numCols; ++col) {
// a single scalar; resets the diagonals
for (int col = 0; col < 4; ++col)
ids[col][col] = sources[0];
} else if (isMatrix(sources[0])) {
// constructing from another matrix; copy over the parts that exist in both the argument and constructee
Id matrix = sources[0];
int minCols = std::min(numCols, getNumColumns(matrix));
int minRows = std::min(numRows, getNumRows(matrix));
for (int col = 0; col < minCols; ++col) {
std::vector<unsigned> indexes; std::vector<unsigned> indexes;
indexes.push_back(col); indexes.push_back(col);
for (int row = 0; row < minRows; ++row) { Id colv = createCompositeExtract(matrix, sourceColumnTypeId, indexes);
indexes.push_back(row); setPrecision(colv, precision);
ids[col][row] = createCompositeExtract(matrix, componentTypeId, indexes);
indexes.pop_back(); if (numRows != getNumRows(matrix)) {
setPrecision(ids[col][row], precision); matrixColumns.push_back(createRvalueSwizzle(precision, columnTypeId, colv, channels));
} else {
matrixColumns.push_back(colv);
} }
} }
return setPrecision(createCompositeConstruct(resultTypeId, matrixColumns), precision);
} else { } else {
// fill in the matrix in column-major order with whatever argument components are available // Will use a two step process
int row = 0; // 1. make a compile-time 2D array of values
int col = 0; // 2. construct a matrix from that array
// Step 1.
// initialize the array to the identity matrix
Id ids[maxMatrixSize][maxMatrixSize];
Id one = (bitCount == 64 ? makeDoubleConstant(1.0) : makeFloatConstant(1.0));
Id zero = (bitCount == 64 ? makeDoubleConstant(0.0) : makeFloatConstant(0.0));
for (int col = 0; col < 4; ++col) {
for (int row = 0; row < 4; ++row) {
if (col == row)
ids[col][row] = one;
else
ids[col][row] = zero;
}
}
for (int arg = 0; arg < (int)sources.size(); ++arg) { // modify components as dictated by the arguments
Id argComp = sources[arg]; if (sources.size() == 1 && isScalar(sources[0])) {
for (int comp = 0; comp < getNumComponents(sources[arg]); ++comp) { // a single scalar; resets the diagonals
if (getNumComponents(sources[arg]) > 1) { for (int col = 0; col < 4; ++col)
argComp = createCompositeExtract(sources[arg], componentTypeId, comp); ids[col][col] = sources[0];
setPrecision(argComp, precision); } else if (isMatrix(sources[0])) {
// constructing from another matrix; copy over the parts that exist in both the argument and constructee
Id matrix = sources[0];
int minCols = std::min(numCols, getNumColumns(matrix));
int minRows = std::min(numRows, getNumRows(matrix));
for (int col = 0; col < minCols; ++col) {
std::vector<unsigned> indexes;
indexes.push_back(col);
for (int row = 0; row < minRows; ++row) {
indexes.push_back(row);
ids[col][row] = createCompositeExtract(matrix, componentTypeId, indexes);
indexes.pop_back();
setPrecision(ids[col][row], precision);
} }
ids[col][row++] = argComp; }
if (row == numRows) { } else {
row = 0; // fill in the matrix in column-major order with whatever argument components are available
col++; int row = 0;
int col = 0;
for (int arg = 0; arg < (int)sources.size(); ++arg) {
Id argComp = sources[arg];
for (int comp = 0; comp < getNumComponents(sources[arg]); ++comp) {
if (getNumComponents(sources[arg]) > 1) {
argComp = createCompositeExtract(sources[arg], componentTypeId, comp);
setPrecision(argComp, precision);
}
ids[col][row++] = argComp;
if (row == numRows) {
row = 0;
col++;
}
} }
} }
} }
}
// Step 2: Construct a matrix from that array. // Step 2: Construct a matrix from that array.
// First make the column vectors, then make the matrix. // First make the column vectors, then make the matrix.
// make the column vectors
Id columnTypeId = getContainedTypeId(resultTypeId);
std::vector<Id> matrixColumns;
for (int col = 0; col < numCols; ++col) {
std::vector<Id> vectorComponents;
for (int row = 0; row < numRows; ++row)
vectorComponents.push_back(ids[col][row]);
Id column = createCompositeConstruct(columnTypeId, vectorComponents);
setPrecision(column, precision);
matrixColumns.push_back(column);
}
// make the column vectors // make the matrix
Id columnTypeId = getContainedTypeId(resultTypeId); return setPrecision(createCompositeConstruct(resultTypeId, matrixColumns), precision);
std::vector<Id> matrixColumns;
for (int col = 0; col < numCols; ++col) {
std::vector<Id> vectorComponents;
for (int row = 0; row < numRows; ++row)
vectorComponents.push_back(ids[col][row]);
Id column = createCompositeConstruct(columnTypeId, vectorComponents);
setPrecision(column, precision);
matrixColumns.push_back(column);
} }
// make the matrix
return setPrecision(createCompositeConstruct(resultTypeId, matrixColumns), precision);
} }
// Comments in header // Comments in header
......
...@@ -251,12 +251,12 @@ Shader version: 500 ...@@ -251,12 +251,12 @@ Shader version: 500
// Module Version 10000 // Module Version 10000
// Generated by (magic number): 80007 // Generated by (magic number): 80007
// Id's are bound by 106 // Id's are bound by 93
Capability Shader Capability Shader
1: ExtInstImport "GLSL.std.450" 1: ExtInstImport "GLSL.std.450"
MemoryModel Logical GLSL450 MemoryModel Logical GLSL450
EntryPoint Vertex 4 "main" 87 91 99 103 EntryPoint Vertex 4 "main" 74 78 86 90
Source HLSL 500 Source HLSL 500
Name 4 "main" Name 4 "main"
Name 9 "VS_INPUT" Name 9 "VS_INPUT"
...@@ -274,13 +274,13 @@ Shader version: 500 ...@@ -274,13 +274,13 @@ Shader version: 500
MemberName 28(C) 1 "View" MemberName 28(C) 1 "View"
MemberName 28(C) 2 "Projection" MemberName 28(C) 2 "Projection"
Name 30 "" Name 30 ""
Name 85 "input" Name 72 "input"
Name 87 "input.Pos" Name 74 "input.Pos"
Name 91 "input.Norm" Name 78 "input.Norm"
Name 94 "flattenTemp" Name 81 "flattenTemp"
Name 95 "param" Name 82 "param"
Name 99 "@entryPointOutput.Pos" Name 86 "@entryPointOutput.Pos"
Name 103 "@entryPointOutput.Norm" Name 90 "@entryPointOutput.Norm"
MemberDecorate 28(C) 0 RowMajor MemberDecorate 28(C) 0 RowMajor
MemberDecorate 28(C) 0 Offset 0 MemberDecorate 28(C) 0 Offset 0
MemberDecorate 28(C) 0 MatrixStride 16 MemberDecorate 28(C) 0 MatrixStride 16
...@@ -293,10 +293,10 @@ Shader version: 500 ...@@ -293,10 +293,10 @@ Shader version: 500
Decorate 28(C) Block Decorate 28(C) Block
Decorate 30 DescriptorSet 0 Decorate 30 DescriptorSet 0
Decorate 30 Binding 0 Decorate 30 Binding 0
Decorate 87(input.Pos) Location 0 Decorate 74(input.Pos) Location 0
Decorate 91(input.Norm) Location 1 Decorate 78(input.Norm) Location 1
Decorate 99(@entryPointOutput.Pos) BuiltIn Position Decorate 86(@entryPointOutput.Pos) BuiltIn Position
Decorate 103(@entryPointOutput.Norm) Location 0 Decorate 90(@entryPointOutput.Norm) Location 0
2: TypeVoid 2: TypeVoid
3: TypeFunction 2 3: TypeFunction 2
6: TypeFloat 32 6: TypeFloat 32
...@@ -324,37 +324,36 @@ Shader version: 500 ...@@ -324,37 +324,36 @@ Shader version: 500
39: 16(int) Constant 1 39: 16(int) Constant 1
46: 16(int) Constant 2 46: 16(int) Constant 2
55: TypeMatrix 7(fvec4) 3 55: TypeMatrix 7(fvec4) 3
56: 6(float) Constant 1065353216 60: TypePointer Function 8(fvec3)
73: TypePointer Function 8(fvec3) 73: TypePointer Input 7(fvec4)
86: TypePointer Input 7(fvec4) 74(input.Pos): 73(ptr) Variable Input
87(input.Pos): 86(ptr) Variable Input 77: TypePointer Input 8(fvec3)
90: TypePointer Input 8(fvec3) 78(input.Norm): 77(ptr) Variable Input
91(input.Norm): 90(ptr) Variable Input 85: TypePointer Output 7(fvec4)
98: TypePointer Output 7(fvec4) 86(@entryPointOutput.Pos): 85(ptr) Variable Output
99(@entryPointOutput.Pos): 98(ptr) Variable Output 89: TypePointer Output 8(fvec3)
102: TypePointer Output 8(fvec3) 90(@entryPointOutput.Norm): 89(ptr) Variable Output
103(@entryPointOutput.Norm): 102(ptr) Variable Output
4(main): 2 Function None 3 4(main): 2 Function None 3
5: Label 5: Label
85(input): 10(ptr) Variable Function 72(input): 10(ptr) Variable Function
94(flattenTemp): 20(ptr) Variable Function 81(flattenTemp): 20(ptr) Variable Function
95(param): 10(ptr) Variable Function 82(param): 10(ptr) Variable Function
88: 7(fvec4) Load 87(input.Pos) 75: 7(fvec4) Load 74(input.Pos)
89: 34(ptr) AccessChain 85(input) 26 76: 34(ptr) AccessChain 72(input) 26
Store 89 88 Store 76 75
92: 8(fvec3) Load 91(input.Norm) 79: 8(fvec3) Load 78(input.Norm)
93: 73(ptr) AccessChain 85(input) 39 80: 60(ptr) AccessChain 72(input) 39
Store 93 92 Store 80 79
96: 9(VS_INPUT) Load 85(input) 83: 9(VS_INPUT) Load 72(input)
Store 95(param) 96 Store 82(param) 83
97:11(PS_INPUT) FunctionCall 14(@main(struct-VS_INPUT-vf4-vf31;) 95(param) 84:11(PS_INPUT) FunctionCall 14(@main(struct-VS_INPUT-vf4-vf31;) 82(param)
Store 94(flattenTemp) 97 Store 81(flattenTemp) 84
100: 34(ptr) AccessChain 94(flattenTemp) 26 87: 34(ptr) AccessChain 81(flattenTemp) 26
101: 7(fvec4) Load 100 88: 7(fvec4) Load 87
Store 99(@entryPointOutput.Pos) 101 Store 86(@entryPointOutput.Pos) 88
104: 73(ptr) AccessChain 94(flattenTemp) 39 91: 60(ptr) AccessChain 81(flattenTemp) 39
105: 8(fvec3) Load 104 92: 8(fvec3) Load 91
Store 103(@entryPointOutput.Norm) 105 Store 90(@entryPointOutput.Norm) 92
Return Return
FunctionEnd FunctionEnd
14(@main(struct-VS_INPUT-vf4-vf31;):11(PS_INPUT) Function None 12 14(@main(struct-VS_INPUT-vf4-vf31;):11(PS_INPUT) Function None 12
...@@ -387,31 +386,19 @@ Shader version: 500 ...@@ -387,31 +386,19 @@ Shader version: 500
Store 52 51 Store 52 51
53: 31(ptr) AccessChain 30 26 53: 31(ptr) AccessChain 30 26
54: 27 Load 53 54: 27 Load 53
57: 6(float) CompositeExtract 54 0 0 56: 7(fvec4) CompositeExtract 54 0
58: 6(float) CompositeExtract 54 0 1 57: 7(fvec4) CompositeExtract 54 1
59: 6(float) CompositeExtract 54 0 2 58: 7(fvec4) CompositeExtract 54 2
60: 6(float) CompositeExtract 54 0 3 59: 55 CompositeConstruct 56 57 58
61: 6(float) CompositeExtract 54 1 0 61: 60(ptr) AccessChain 13(input) 39
62: 6(float) CompositeExtract 54 1 1 62: 8(fvec3) Load 61
63: 6(float) CompositeExtract 54 1 2 63: 7(fvec4) MatrixTimesVector 59 62
64: 6(float) CompositeExtract 54 1 3 64: 6(float) CompositeExtract 63 0
65: 6(float) CompositeExtract 54 2 0 65: 6(float) CompositeExtract 63 1
66: 6(float) CompositeExtract 54 2 1 66: 6(float) CompositeExtract 63 2
67: 6(float) CompositeExtract 54 2 2 67: 8(fvec3) CompositeConstruct 64 65 66
68: 6(float) CompositeExtract 54 2 3 68: 60(ptr) AccessChain 21(output) 39
69: 7(fvec4) CompositeConstruct 57 58 59 60 Store 68 67
70: 7(fvec4) CompositeConstruct 61 62 63 64 69:11(PS_INPUT) Load 21(output)
71: 7(fvec4) CompositeConstruct 65 66 67 68 ReturnValue 69
72: 55 CompositeConstruct 69 70 71
74: 73(ptr) AccessChain 13(input) 39
75: 8(fvec3) Load 74
76: 7(fvec4) MatrixTimesVector 72 75
77: 6(float) CompositeExtract 76 0
78: 6(float) CompositeExtract 76 1
79: 6(float) CompositeExtract 76 2
80: 8(fvec3) CompositeConstruct 77 78 79
81: 73(ptr) AccessChain 21(output) 39
Store 81 80
82:11(PS_INPUT) Load 21(output)
ReturnValue 82
FunctionEnd FunctionEnd
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment