Commit 13dcbece by Ben Clayton

src/Pipeline: Refactor ComputeProgram

• Split up ComputeProgram::emit() into a few smaller functions.
• Calculate the subgroup count in C++, and pass it in as a parameter.
• Add a firstSubgroup parameter; this is currently always 0, but will be used in a later change.
• Pass the workgroup ID as parameters instead of through Data. Data now holds fields common to all workgroups.

This refactoring prepares the code for migrating to coroutines.

Bug: b/131672705
Change-Id: Id3492adc0a7aedc3f16c0e37f135294862c55700
Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/30848
Tested-by: Ben Clayton <bclayton@google.com>
Presubmit-Ready: Ben Clayton <bclayton@google.com>
Reviewed-by: Nicolas Capens <nicolascapens@google.com>
Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
parent 1c82c7b8
...@@ -45,30 +45,11 @@ namespace sw ...@@ -45,30 +45,11 @@ namespace sw
shader->emitEpilog(&routine); shader->emitEpilog(&routine);
} }
void ComputeProgram::emit() void ComputeProgram::setWorkgroupBuiltins(Int workgroupID[3])
{ {
routine.descriptorSets = data + OFFSET(Data, descriptorSets);
routine.descriptorDynamicOffsets = data + OFFSET(Data, descriptorDynamicOffsets);
routine.pushConstants = data + OFFSET(Data, pushConstants);
routine.constants = *Pointer<Pointer<Byte>>(data + OFFSET(Data, constants));
routine.workgroupMemory = *Pointer<Pointer<Byte>>(data + OFFSET(Data, workgroupMemory));
auto &modes = shader->getModes();
int localSize[3] = {modes.WorkgroupSizeX, modes.WorkgroupSizeY, modes.WorkgroupSizeZ};
const int subgroupSize = SIMD::Width;
// Total number of invocations required to execute this workgroup.
int numInvocations = localSize[X] * localSize[Y] * localSize[Z];
Int4 numWorkgroups = *Pointer<Int4>(data + OFFSET(Data, numWorkgroups));
Int4 workgroupID = *Pointer<Int4>(data + OFFSET(Data, workgroupID));
Int4 workgroupSize = Int4(localSize[X], localSize[Y], localSize[Z], 0);
Int numSubgroups = (numInvocations + subgroupSize - 1) / subgroupSize;
setInputBuiltin(spv::BuiltInNumWorkgroups, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value) setInputBuiltin(spv::BuiltInNumWorkgroups, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
{ {
auto numWorkgroups = *Pointer<Int4>(data + OFFSET(Data, numWorkgroups));
for (uint32_t component = 0; component < builtin.SizeInComponents; component++) for (uint32_t component = 0; component < builtin.SizeInComponents; component++)
{ {
value[builtin.FirstComponent + component] = value[builtin.FirstComponent + component] =
...@@ -81,12 +62,13 @@ namespace sw ...@@ -81,12 +62,13 @@ namespace sw
for (uint32_t component = 0; component < builtin.SizeInComponents; component++) for (uint32_t component = 0; component < builtin.SizeInComponents; component++)
{ {
value[builtin.FirstComponent + component] = value[builtin.FirstComponent + component] =
As<SIMD::Float>(SIMD::Int(Extract(workgroupID, component))); As<SIMD::Float>(SIMD::Int(workgroupID[component]));
} }
}); });
setInputBuiltin(spv::BuiltInWorkgroupSize, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value) setInputBuiltin(spv::BuiltInWorkgroupSize, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
{ {
auto workgroupSize = *Pointer<Int4>(data + OFFSET(Data, workgroupSize));
for (uint32_t component = 0; component < builtin.SizeInComponents; component++) for (uint32_t component = 0; component < builtin.SizeInComponents; component++)
{ {
value[builtin.FirstComponent + component] = value[builtin.FirstComponent + component] =
...@@ -97,13 +79,15 @@ namespace sw ...@@ -97,13 +79,15 @@ namespace sw
setInputBuiltin(spv::BuiltInNumSubgroups, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value) setInputBuiltin(spv::BuiltInNumSubgroups, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
{ {
ASSERT(builtin.SizeInComponents == 1); ASSERT(builtin.SizeInComponents == 1);
value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(numSubgroups)); auto subgroupsPerWorkgroup = *Pointer<Int>(data + OFFSET(Data, subgroupsPerWorkgroup));
value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(subgroupsPerWorkgroup));
}); });
setInputBuiltin(spv::BuiltInSubgroupSize, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value) setInputBuiltin(spv::BuiltInSubgroupSize, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
{ {
ASSERT(builtin.SizeInComponents == 1); ASSERT(builtin.SizeInComponents == 1);
value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(subgroupSize)); auto invocationsPerSubgroup = *Pointer<Int>(data + OFFSET(Data, invocationsPerSubgroup));
value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(invocationsPerSubgroup));
}); });
setInputBuiltin(spv::BuiltInSubgroupLocalInvocationId, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value) setInputBuiltin(spv::BuiltInSubgroupLocalInvocationId, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
...@@ -111,22 +95,24 @@ namespace sw ...@@ -111,22 +95,24 @@ namespace sw
ASSERT(builtin.SizeInComponents == 1); ASSERT(builtin.SizeInComponents == 1);
value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(0, 1, 2, 3)); value[builtin.FirstComponent] = As<SIMD::Float>(SIMD::Int(0, 1, 2, 3));
}); });
}
For(Int subgroupIndex = 0, subgroupIndex < numSubgroups, subgroupIndex++) void ComputeProgram::setSubgroupBuiltins(Int workgroupID[3], SIMD::Int localInvocationIndex, Int subgroupIndex)
{ {
// TODO: Replace SIMD::Int(0, 1, 2, 3) with SIMD-width equivalent Int4 numWorkgroups = *Pointer<Int4>(data + OFFSET(Data, numWorkgroups));
auto localInvocationIndex = SIMD::Int(subgroupIndex * SIMD::Width) + SIMD::Int(0, 1, 2, 3); Int4 workgroupSize = *Pointer<Int4>(data + OFFSET(Data, workgroupSize));
// Disable lanes where (invocationIDs >= numInvocations) // TODO: Fix Int4 swizzles so we can just use workgroupSize.x, workgroupSize.y.
auto activeLaneMask = CmpLT(localInvocationIndex, SIMD::Int(numInvocations)); Int workgroupSizeX = Extract(workgroupSize, X);
Int workgroupSizeY = Extract(workgroupSize, Y);
SIMD::Int localInvocationID[3]; SIMD::Int localInvocationID[3];
{ {
SIMD::Int idx = localInvocationIndex; SIMD::Int idx = localInvocationIndex;
localInvocationID[Z] = idx / SIMD::Int(localSize[X] * localSize[Y]); localInvocationID[Z] = idx / SIMD::Int(workgroupSizeX * workgroupSizeY);
idx -= localInvocationID[Z] * SIMD::Int(localSize[X] * localSize[Y]); // modulo idx -= localInvocationID[Z] * SIMD::Int(workgroupSizeX * workgroupSizeY); // modulo
localInvocationID[Y] = idx / SIMD::Int(localSize[X]); localInvocationID[Y] = idx / SIMD::Int(workgroupSizeX);
idx -= localInvocationID[Y] * SIMD::Int(localSize[X]); // modulo idx -= localInvocationID[Y] * SIMD::Int(workgroupSizeX); // modulo
localInvocationID[X] = idx; localInvocationID[X] = idx;
} }
...@@ -146,21 +132,57 @@ namespace sw ...@@ -146,21 +132,57 @@ namespace sw
{ {
for (uint32_t component = 0; component < builtin.SizeInComponents; component++) for (uint32_t component = 0; component < builtin.SizeInComponents; component++)
{ {
value[builtin.FirstComponent + component] = As<SIMD::Float>(localInvocationID[component]); value[builtin.FirstComponent + component] =
As<SIMD::Float>(localInvocationID[component]);
} }
}); });
setInputBuiltin(spv::BuiltInGlobalInvocationId, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value) setInputBuiltin(spv::BuiltInGlobalInvocationId, [&](const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)
{ {
auto localBase = workgroupID * workgroupSize; SIMD::Int wgID = 0;
wgID = Insert(wgID, workgroupID[X], X);
wgID = Insert(wgID, workgroupID[Y], Y);
wgID = Insert(wgID, workgroupID[Z], Z);
auto localBase = workgroupSize * wgID;
for (uint32_t component = 0; component < builtin.SizeInComponents; component++) for (uint32_t component = 0; component < builtin.SizeInComponents; component++)
{ {
auto globalInvocationID = SIMD::Int(Extract(localBase, component)) + localInvocationID[component]; auto globalInvocationID = SIMD::Int(Extract(localBase, component)) + localInvocationID[component];
value[builtin.FirstComponent + component] = As<SIMD::Float>(globalInvocationID); value[builtin.FirstComponent + component] = As<SIMD::Float>(globalInvocationID);
} }
}); });
}
void ComputeProgram::emit()
{
routine.descriptorSets = data + OFFSET(Data, descriptorSets);
routine.descriptorDynamicOffsets = data + OFFSET(Data, descriptorDynamicOffsets);
routine.pushConstants = data + OFFSET(Data, pushConstants);
routine.constants = *Pointer<Pointer<Byte>>(data + OFFSET(Data, constants));
routine.workgroupMemory = *Pointer<Pointer<Byte>>(data + OFFSET(Data, workgroupMemory));
Int workgroupX = Arg<1>();
Int workgroupY = Arg<2>();
Int workgroupZ = Arg<3>();
Int firstSubgroup = Arg<4>();
Int subgroupCount = Arg<5>();
Int invocationsPerWorkgroup = *Pointer<Int>(data + OFFSET(Data, invocationsPerWorkgroup));
Int workgroupID[3] = {workgroupX, workgroupY, workgroupZ};
setWorkgroupBuiltins(workgroupID);
For(Int i = 0, i < subgroupCount, i++)
{
auto subgroupIndex = firstSubgroup + i;
// TODO: Replace SIMD::Int(0, 1, 2, 3) with SIMD-width equivalent
auto localInvocationIndex = SIMD::Int(subgroupIndex * SIMD::Width) + SIMD::Int(0, 1, 2, 3);
// Disable lanes where (invocationIDs >= invocationsPerWorkgroup)
auto activeLaneMask = CmpLT(localInvocationIndex, SIMD::Int(invocationsPerWorkgroup));
setSubgroupBuiltins(workgroupID, localInvocationIndex, subgroupIndex);
// Process numLanes of the workgroup.
shader->emit(&routine, activeLaneMask, descriptorSets); shader->emit(&routine, activeLaneMask, descriptorSets);
} }
} }
...@@ -182,7 +204,13 @@ namespace sw ...@@ -182,7 +204,13 @@ namespace sw
PushConstantStorage const &pushConstants, PushConstantStorage const &pushConstants,
uint32_t groupCountX, uint32_t groupCountY, uint32_t groupCountZ) uint32_t groupCountX, uint32_t groupCountY, uint32_t groupCountZ)
{ {
auto runWorkgroup = (void(*)(void*))(routine->getEntry()); auto runWorkgroup = (void(*)(void*, int, int, int, int, int))(routine->getEntry());
auto &modes = shader->getModes();
auto invocationsPerSubgroup = SIMD::Width;
auto invocationsPerWorkgroup = modes.WorkgroupSizeX * modes.WorkgroupSizeY * modes.WorkgroupSizeZ;
auto subgroupsPerWorkgroup = (invocationsPerWorkgroup + invocationsPerSubgroup - 1) / invocationsPerSubgroup;
// We're sharing a buffer here across all workgroups. // We're sharing a buffer here across all workgroups.
// We can only do this because we know workgroups are executed // We can only do this because we know workgroups are executed
...@@ -196,6 +224,13 @@ namespace sw ...@@ -196,6 +224,13 @@ namespace sw
data.numWorkgroups[Y] = groupCountY; data.numWorkgroups[Y] = groupCountY;
data.numWorkgroups[Z] = groupCountZ; data.numWorkgroups[Z] = groupCountZ;
data.numWorkgroups[3] = 0; data.numWorkgroups[3] = 0;
data.workgroupSize[X] = modes.WorkgroupSizeX;
data.workgroupSize[Y] = modes.WorkgroupSizeY;
data.workgroupSize[Z] = modes.WorkgroupSizeZ;
data.workgroupSize[3] = 0;
data.invocationsPerSubgroup = invocationsPerSubgroup;
data.invocationsPerWorkgroup = invocationsPerWorkgroup;
data.subgroupsPerWorkgroup = subgroupsPerWorkgroup;
data.pushConstants = pushConstants; data.pushConstants = pushConstants;
data.constants = &sw::constants; data.constants = &sw::constants;
data.workgroupMemory = workgroupMemory.data(); data.workgroupMemory = workgroupMemory.data();
...@@ -203,16 +238,14 @@ namespace sw ...@@ -203,16 +238,14 @@ namespace sw
// TODO(bclayton): Split work across threads. // TODO(bclayton): Split work across threads.
for (uint32_t groupZ = 0; groupZ < groupCountZ; groupZ++) for (uint32_t groupZ = 0; groupZ < groupCountZ; groupZ++)
{ {
data.workgroupID[Z] = groupZ;
for (uint32_t groupY = 0; groupY < groupCountY; groupY++) for (uint32_t groupY = 0; groupY < groupCountY; groupY++)
{ {
data.workgroupID[Y] = groupY;
for (uint32_t groupX = 0; groupX < groupCountX; groupX++) for (uint32_t groupX = 0; groupX < groupCountX; groupX++)
{ {
data.workgroupID[X] = groupX; runWorkgroup(&data, groupX, groupY, groupZ, 0, subgroupsPerWorkgroup);
runWorkgroup(&data);
} }
} }
} }
} }
}
} // namespace sw
...@@ -37,7 +37,13 @@ namespace sw ...@@ -37,7 +37,13 @@ namespace sw
struct Constants; struct Constants;
// ComputeProgram builds a SPIR-V compute shader. // ComputeProgram builds a SPIR-V compute shader.
class ComputeProgram : public Function<Void(Pointer<Byte>)> class ComputeProgram : public Function<Void(
Pointer<Byte> data,
Int workgroupX,
Int workgroupY,
Int workgroupZ,
Int firstSubgroup,
Int subgroupCount)>
{ {
public: public:
ComputeProgram(SpirvShader const *spirvShader, vk::PipelineLayout const *pipelineLayout, const vk::DescriptorSet::Bindings &descriptorSets); ComputeProgram(SpirvShader const *spirvShader, vk::PipelineLayout const *pipelineLayout, const vk::DescriptorSet::Bindings &descriptorSets);
...@@ -59,6 +65,8 @@ namespace sw ...@@ -59,6 +65,8 @@ namespace sw
protected: protected:
void emit(); void emit();
void setWorkgroupBuiltins(Int workgroupID[3]);
void setSubgroupBuiltins(Int workgroupID[3], SIMD::Int localInvocationIndex, Int subgroupIndex);
void setInputBuiltin(spv::BuiltIn id, std::function<void(const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)> cb); void setInputBuiltin(spv::BuiltIn id, std::function<void(const SpirvShader::BuiltinMapping& builtin, Array<SIMD::Float>& value)> cb);
Pointer<Byte> data; // argument 0 Pointer<Byte> data; // argument 0
...@@ -67,8 +75,11 @@ namespace sw ...@@ -67,8 +75,11 @@ namespace sw
{ {
vk::DescriptorSet::Bindings descriptorSets; vk::DescriptorSet::Bindings descriptorSets;
vk::DescriptorSet::DynamicOffsets descriptorDynamicOffsets; vk::DescriptorSet::DynamicOffsets descriptorDynamicOffsets;
uint4 numWorkgroups; uint4 numWorkgroups; // [x, y, z, 0]
uint4 workgroupID; uint4 workgroupSize; // [x, y, z, 0]
uint32_t invocationsPerSubgroup; // SPIR-V: "SubgroupSize"
uint32_t subgroupsPerWorkgroup; // SPIR-V: "NumSubgroups"
uint32_t invocationsPerWorkgroup; // Total number of invocations per workgroup.
PushConstantStorage pushConstants; PushConstantStorage pushConstants;
const Constants *constants; const Constants *constants;
uint8_t* workgroupMemory; uint8_t* workgroupMemory;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment