Commit 112e2858 by Arseny Kapoulkine

SPIRV: Simplify matrix->matrix constructor

When constructing a matrix from another matrix with smaller dimensions, there's no need to extract the scalars out of columns and rebuild the resulting matrix from scalars - instead, we can just construct shorter vectors with OpShuffle and combine them to the final result. This keeps the common casts such as mat3(mat4) in vector registers, which may improve performance for some GPUs, and cleans up output of translation tools like SPIRV-Cross. Fixes #1412.
parent cd57b4ba
......@@ -2031,83 +2031,113 @@ Id Builder::createMatrixConstructor(Decoration precision, const std::vector<Id>&
Instruction* instr = module.getInstruction(componentTypeId);
Id bitCount = instr->getIdOperand(0);
// Will use a two step process
// 1. make a compile-time 2D array of values
// 2. construct a matrix from that array
// Step 1.
// initialize the array to the identity matrix
Id ids[maxMatrixSize][maxMatrixSize];
Id one = (bitCount == 64 ? makeDoubleConstant(1.0) : makeFloatConstant(1.0));
Id zero = (bitCount == 64 ? makeDoubleConstant(0.0) : makeFloatConstant(0.0));
for (int col = 0; col < 4; ++col) {
for (int row = 0; row < 4; ++row) {
if (col == row)
ids[col][row] = one;
else
ids[col][row] = zero;
if (isMatrix(sources[0]) && getNumColumns(sources[0]) >= numCols && getNumRows(sources[0]) >= numRows) {
// To truncate the matrix to a smaller number of rows/columns, we need to:
// 1. For each column, extract the column and truncate it to the required size using shuffle
// 2. Assemble the resulting matrix from all columns
Id matrix = sources[0];
Id columnTypeId = getContainedTypeId(resultTypeId);
Id sourceColumnTypeId = getContainedTypeId(getTypeId(matrix));
std::vector<unsigned> channels;
for (int row = 0; row < numRows; ++row) {
channels.push_back(row);
}
}
// modify components as dictated by the arguments
if (sources.size() == 1 && isScalar(sources[0])) {
// a single scalar; resets the diagonals
for (int col = 0; col < 4; ++col)
ids[col][col] = sources[0];
} else if (isMatrix(sources[0])) {
// constructing from another matrix; copy over the parts that exist in both the argument and constructee
Id matrix = sources[0];
int minCols = std::min(numCols, getNumColumns(matrix));
int minRows = std::min(numRows, getNumRows(matrix));
for (int col = 0; col < minCols; ++col) {
std::vector<Id> matrixColumns;
for (int col = 0; col < numCols; ++col) {
std::vector<unsigned> indexes;
indexes.push_back(col);
for (int row = 0; row < minRows; ++row) {
indexes.push_back(row);
ids[col][row] = createCompositeExtract(matrix, componentTypeId, indexes);
indexes.pop_back();
setPrecision(ids[col][row], precision);
Id colv = createCompositeExtract(matrix, sourceColumnTypeId, indexes);
setPrecision(colv, precision);
if (numRows != getNumRows(matrix)) {
matrixColumns.push_back(createRvalueSwizzle(precision, columnTypeId, colv, channels));
} else {
matrixColumns.push_back(colv);
}
}
return setPrecision(createCompositeConstruct(resultTypeId, matrixColumns), precision);
} else {
// fill in the matrix in column-major order with whatever argument components are available
int row = 0;
int col = 0;
// Will use a two step process
// 1. make a compile-time 2D array of values
// 2. construct a matrix from that array
// Step 1.
// initialize the array to the identity matrix
Id ids[maxMatrixSize][maxMatrixSize];
Id one = (bitCount == 64 ? makeDoubleConstant(1.0) : makeFloatConstant(1.0));
Id zero = (bitCount == 64 ? makeDoubleConstant(0.0) : makeFloatConstant(0.0));
for (int col = 0; col < 4; ++col) {
for (int row = 0; row < 4; ++row) {
if (col == row)
ids[col][row] = one;
else
ids[col][row] = zero;
}
}
for (int arg = 0; arg < (int)sources.size(); ++arg) {
Id argComp = sources[arg];
for (int comp = 0; comp < getNumComponents(sources[arg]); ++comp) {
if (getNumComponents(sources[arg]) > 1) {
argComp = createCompositeExtract(sources[arg], componentTypeId, comp);
setPrecision(argComp, precision);
// modify components as dictated by the arguments
if (sources.size() == 1 && isScalar(sources[0])) {
// a single scalar; resets the diagonals
for (int col = 0; col < 4; ++col)
ids[col][col] = sources[0];
} else if (isMatrix(sources[0])) {
// constructing from another matrix; copy over the parts that exist in both the argument and constructee
Id matrix = sources[0];
int minCols = std::min(numCols, getNumColumns(matrix));
int minRows = std::min(numRows, getNumRows(matrix));
for (int col = 0; col < minCols; ++col) {
std::vector<unsigned> indexes;
indexes.push_back(col);
for (int row = 0; row < minRows; ++row) {
indexes.push_back(row);
ids[col][row] = createCompositeExtract(matrix, componentTypeId, indexes);
indexes.pop_back();
setPrecision(ids[col][row], precision);
}
ids[col][row++] = argComp;
if (row == numRows) {
row = 0;
col++;
}
} else {
// fill in the matrix in column-major order with whatever argument components are available
int row = 0;
int col = 0;
for (int arg = 0; arg < (int)sources.size(); ++arg) {
Id argComp = sources[arg];
for (int comp = 0; comp < getNumComponents(sources[arg]); ++comp) {
if (getNumComponents(sources[arg]) > 1) {
argComp = createCompositeExtract(sources[arg], componentTypeId, comp);
setPrecision(argComp, precision);
}
ids[col][row++] = argComp;
if (row == numRows) {
row = 0;
col++;
}
}
}
}
}
// Step 2: Construct a matrix from that array.
// First make the column vectors, then make the matrix.
// Step 2: Construct a matrix from that array.
// First make the column vectors, then make the matrix.
// make the column vectors
Id columnTypeId = getContainedTypeId(resultTypeId);
std::vector<Id> matrixColumns;
for (int col = 0; col < numCols; ++col) {
std::vector<Id> vectorComponents;
for (int row = 0; row < numRows; ++row)
vectorComponents.push_back(ids[col][row]);
Id column = createCompositeConstruct(columnTypeId, vectorComponents);
setPrecision(column, precision);
matrixColumns.push_back(column);
}
// make the column vectors
Id columnTypeId = getContainedTypeId(resultTypeId);
std::vector<Id> matrixColumns;
for (int col = 0; col < numCols; ++col) {
std::vector<Id> vectorComponents;
for (int row = 0; row < numRows; ++row)
vectorComponents.push_back(ids[col][row]);
Id column = createCompositeConstruct(columnTypeId, vectorComponents);
setPrecision(column, precision);
matrixColumns.push_back(column);
// make the matrix
return setPrecision(createCompositeConstruct(resultTypeId, matrixColumns), precision);
}
// make the matrix
return setPrecision(createCompositeConstruct(resultTypeId, matrixColumns), precision);
}
// Comments in header
......
......@@ -251,12 +251,12 @@ Shader version: 500
// Module Version 10000
// Generated by (magic number): 80007
// Id's are bound by 106
// Id's are bound by 93
Capability Shader
1: ExtInstImport "GLSL.std.450"
MemoryModel Logical GLSL450
EntryPoint Vertex 4 "main" 87 91 99 103
EntryPoint Vertex 4 "main" 74 78 86 90
Source HLSL 500
Name 4 "main"
Name 9 "VS_INPUT"
......@@ -274,13 +274,13 @@ Shader version: 500
MemberName 28(C) 1 "View"
MemberName 28(C) 2 "Projection"
Name 30 ""
Name 85 "input"
Name 87 "input.Pos"
Name 91 "input.Norm"
Name 94 "flattenTemp"
Name 95 "param"
Name 99 "@entryPointOutput.Pos"
Name 103 "@entryPointOutput.Norm"
Name 72 "input"
Name 74 "input.Pos"
Name 78 "input.Norm"
Name 81 "flattenTemp"
Name 82 "param"
Name 86 "@entryPointOutput.Pos"
Name 90 "@entryPointOutput.Norm"
MemberDecorate 28(C) 0 RowMajor
MemberDecorate 28(C) 0 Offset 0
MemberDecorate 28(C) 0 MatrixStride 16
......@@ -293,10 +293,10 @@ Shader version: 500
Decorate 28(C) Block
Decorate 30 DescriptorSet 0
Decorate 30 Binding 0
Decorate 87(input.Pos) Location 0
Decorate 91(input.Norm) Location 1
Decorate 99(@entryPointOutput.Pos) BuiltIn Position
Decorate 103(@entryPointOutput.Norm) Location 0
Decorate 74(input.Pos) Location 0
Decorate 78(input.Norm) Location 1
Decorate 86(@entryPointOutput.Pos) BuiltIn Position
Decorate 90(@entryPointOutput.Norm) Location 0
2: TypeVoid
3: TypeFunction 2
6: TypeFloat 32
......@@ -324,37 +324,36 @@ Shader version: 500
39: 16(int) Constant 1
46: 16(int) Constant 2
55: TypeMatrix 7(fvec4) 3
56: 6(float) Constant 1065353216
73: TypePointer Function 8(fvec3)
86: TypePointer Input 7(fvec4)
87(input.Pos): 86(ptr) Variable Input
90: TypePointer Input 8(fvec3)
91(input.Norm): 90(ptr) Variable Input
98: TypePointer Output 7(fvec4)
99(@entryPointOutput.Pos): 98(ptr) Variable Output
102: TypePointer Output 8(fvec3)
103(@entryPointOutput.Norm): 102(ptr) Variable Output
60: TypePointer Function 8(fvec3)
73: TypePointer Input 7(fvec4)
74(input.Pos): 73(ptr) Variable Input
77: TypePointer Input 8(fvec3)
78(input.Norm): 77(ptr) Variable Input
85: TypePointer Output 7(fvec4)
86(@entryPointOutput.Pos): 85(ptr) Variable Output
89: TypePointer Output 8(fvec3)
90(@entryPointOutput.Norm): 89(ptr) Variable Output
4(main): 2 Function None 3
5: Label
85(input): 10(ptr) Variable Function
94(flattenTemp): 20(ptr) Variable Function
95(param): 10(ptr) Variable Function
88: 7(fvec4) Load 87(input.Pos)
89: 34(ptr) AccessChain 85(input) 26
Store 89 88
92: 8(fvec3) Load 91(input.Norm)
93: 73(ptr) AccessChain 85(input) 39
Store 93 92
96: 9(VS_INPUT) Load 85(input)
Store 95(param) 96
97:11(PS_INPUT) FunctionCall 14(@main(struct-VS_INPUT-vf4-vf31;) 95(param)
Store 94(flattenTemp) 97
100: 34(ptr) AccessChain 94(flattenTemp) 26
101: 7(fvec4) Load 100
Store 99(@entryPointOutput.Pos) 101
104: 73(ptr) AccessChain 94(flattenTemp) 39
105: 8(fvec3) Load 104
Store 103(@entryPointOutput.Norm) 105
72(input): 10(ptr) Variable Function
81(flattenTemp): 20(ptr) Variable Function
82(param): 10(ptr) Variable Function
75: 7(fvec4) Load 74(input.Pos)
76: 34(ptr) AccessChain 72(input) 26
Store 76 75
79: 8(fvec3) Load 78(input.Norm)
80: 60(ptr) AccessChain 72(input) 39
Store 80 79
83: 9(VS_INPUT) Load 72(input)
Store 82(param) 83
84:11(PS_INPUT) FunctionCall 14(@main(struct-VS_INPUT-vf4-vf31;) 82(param)
Store 81(flattenTemp) 84
87: 34(ptr) AccessChain 81(flattenTemp) 26
88: 7(fvec4) Load 87
Store 86(@entryPointOutput.Pos) 88
91: 60(ptr) AccessChain 81(flattenTemp) 39
92: 8(fvec3) Load 91
Store 90(@entryPointOutput.Norm) 92
Return
FunctionEnd
14(@main(struct-VS_INPUT-vf4-vf31;):11(PS_INPUT) Function None 12
......@@ -387,31 +386,19 @@ Shader version: 500
Store 52 51
53: 31(ptr) AccessChain 30 26
54: 27 Load 53
57: 6(float) CompositeExtract 54 0 0
58: 6(float) CompositeExtract 54 0 1
59: 6(float) CompositeExtract 54 0 2
60: 6(float) CompositeExtract 54 0 3
61: 6(float) CompositeExtract 54 1 0
62: 6(float) CompositeExtract 54 1 1
63: 6(float) CompositeExtract 54 1 2
64: 6(float) CompositeExtract 54 1 3
65: 6(float) CompositeExtract 54 2 0
66: 6(float) CompositeExtract 54 2 1
67: 6(float) CompositeExtract 54 2 2
68: 6(float) CompositeExtract 54 2 3
69: 7(fvec4) CompositeConstruct 57 58 59 60
70: 7(fvec4) CompositeConstruct 61 62 63 64
71: 7(fvec4) CompositeConstruct 65 66 67 68
72: 55 CompositeConstruct 69 70 71
74: 73(ptr) AccessChain 13(input) 39
75: 8(fvec3) Load 74
76: 7(fvec4) MatrixTimesVector 72 75
77: 6(float) CompositeExtract 76 0
78: 6(float) CompositeExtract 76 1
79: 6(float) CompositeExtract 76 2
80: 8(fvec3) CompositeConstruct 77 78 79
81: 73(ptr) AccessChain 21(output) 39
Store 81 80
82:11(PS_INPUT) Load 21(output)
ReturnValue 82
56: 7(fvec4) CompositeExtract 54 0
57: 7(fvec4) CompositeExtract 54 1
58: 7(fvec4) CompositeExtract 54 2
59: 55 CompositeConstruct 56 57 58
61: 60(ptr) AccessChain 13(input) 39
62: 8(fvec3) Load 61
63: 7(fvec4) MatrixTimesVector 59 62
64: 6(float) CompositeExtract 63 0
65: 6(float) CompositeExtract 63 1
66: 6(float) CompositeExtract 63 2
67: 8(fvec3) CompositeConstruct 64 65 66
68: 60(ptr) AccessChain 21(output) 39
Store 68 67
69:11(PS_INPUT) Load 21(output)
ReturnValue 69
FunctionEnd
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment