Commit e51f8593 by Ben Clayton

SpirvShader: Correctly handle phi values in the loop merge

Yet another horrible phi/loop edge case (pun intended). Added test. Bug: b/133440380 Bug: b/133481698 Change-Id: I327842fa2d4314bce938454da81f67f890cf9e12 Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/31845 Presubmit-Ready: Ben Clayton <bclayton@google.com> Kokoro-Presubmit: kokoro <noreply+kokoro@google.com> Tested-by: 's avatarBen Clayton <bclayton@google.com> Reviewed-by: 's avatarNicolas Capens <nicolascapens@google.com>
parent 2a985265
...@@ -984,7 +984,7 @@ namespace sw ...@@ -984,7 +984,7 @@ namespace sw
} }
ASSERT_MSG(entryPointFunctionId != 0, "Entry point '%s' not found", createInfo->pName); ASSERT_MSG(entryPointFunctionId != 0, "Entry point '%s' not found", createInfo->pName);
AssignBlockIns(); AssignBlockFields();
} }
void SpirvShader::TraverseReachableBlocks(Block::ID id, SpirvShader::Block::Set& reachable) void SpirvShader::TraverseReachableBlocks(Block::ID id, SpirvShader::Block::Set& reachable)
...@@ -999,7 +999,7 @@ namespace sw ...@@ -999,7 +999,7 @@ namespace sw
} }
} }
void SpirvShader::AssignBlockIns() void SpirvShader::AssignBlockFields()
{ {
Block::Set reachable; Block::Set reachable;
TraverseReachableBlocks(entryPointBlockId, reachable); TraverseReachableBlocks(entryPointBlockId, reachable);
...@@ -1007,6 +1007,7 @@ namespace sw ...@@ -1007,6 +1007,7 @@ namespace sw
for (auto &it : blocks) for (auto &it : blocks)
{ {
auto &blockId = it.first; auto &blockId = it.first;
auto &block = it.second;
if (reachable.count(blockId) > 0) if (reachable.count(blockId) > 0)
{ {
for (auto &outId : it.second.outs) for (auto &outId : it.second.outs)
...@@ -1016,6 +1017,12 @@ namespace sw ...@@ -1016,6 +1017,12 @@ namespace sw
auto &out = outIt->second; auto &out = outIt->second;
out.ins.emplace(blockId); out.ins.emplace(blockId);
} }
if (block.kind == Block::Loop)
{
auto mergeIt = blocks.find(block.mergeBlock);
ASSERT_MSG(mergeIt != blocks.end(), "Loop block %d has a non-existent merge block %d", blockId.value(), block.mergeBlock.value());
mergeIt->second.isLoopMerge = true;
}
} }
} }
} }
...@@ -2083,7 +2090,9 @@ namespace sw ...@@ -2083,7 +2090,9 @@ namespace sw
void SpirvShader::EmitLoop(EmitState *state) const void SpirvShader::EmitLoop(EmitState *state) const
{ {
auto blockId = state->currentBlock; auto blockId = state->currentBlock;
auto block = getBlock(blockId); auto &block = getBlock(blockId);
auto mergeBlockId = block.mergeBlock;
auto &mergeBlock = getBlock(mergeBlockId);
// Ensure all incoming non-back edge blocks have been generated. // Ensure all incoming non-back edge blocks have been generated.
auto depsDone = true; auto depsDone = true;
...@@ -2091,7 +2100,7 @@ namespace sw ...@@ -2091,7 +2100,7 @@ namespace sw
{ {
if (state->visited.count(in) == 0) if (state->visited.count(in) == 0)
{ {
if (!existsPath(blockId, in, block.mergeBlock)) // if not a loop back edge if (!existsPath(blockId, in, mergeBlockId)) // if not a loop back edge
{ {
state->pending->emplace(in); state->pending->emplace(in);
depsDone = false; depsDone = false;
...@@ -2115,7 +2124,7 @@ namespace sw ...@@ -2115,7 +2124,7 @@ namespace sw
std::unordered_set<Block::ID> loopBlocks; std::unordered_set<Block::ID> loopBlocks;
for (auto in : block.ins) for (auto in : block.ins)
{ {
if (!existsPath(blockId, in, block.mergeBlock)) // if not a loop back-edge if (!existsPath(blockId, in, mergeBlockId)) // if not a loop back-edge
{ {
incomingBlocks.emplace(in); incomingBlocks.emplace(in);
} }
...@@ -2131,7 +2140,7 @@ namespace sw ...@@ -2131,7 +2140,7 @@ namespace sw
{ {
if (insn.opcode() == spv::OpPhi) if (insn.opcode() == spv::OpPhi)
{ {
StorePhi(insn, state, incomingBlocks); StorePhi(blockId, insn, state, incomingBlocks);
} }
} }
...@@ -2146,7 +2155,7 @@ namespace sw ...@@ -2146,7 +2155,7 @@ namespace sw
// mergeActiveLaneMasks contains edge lane masks for the merge block. // mergeActiveLaneMasks contains edge lane masks for the merge block.
// This is the union of all edge masks across all iterations of the loop. // This is the union of all edge masks across all iterations of the loop.
std::unordered_map<Block::ID, SIMD::Int> mergeActiveLaneMasks; std::unordered_map<Block::ID, SIMD::Int> mergeActiveLaneMasks;
for (auto in : getBlock(block.mergeBlock).ins) for (auto in : getBlock(mergeBlockId).ins)
{ {
mergeActiveLaneMasks.emplace(in, SIMD::Int(0)); mergeActiveLaneMasks.emplace(in, SIMD::Int(0));
} }
...@@ -2179,7 +2188,7 @@ namespace sw ...@@ -2179,7 +2188,7 @@ namespace sw
// don't emit the merge block yet. // don't emit the merge block yet.
for (auto out : block.outs) for (auto out : block.outs)
{ {
EmitBlocks(out, state, block.mergeBlock); EmitBlocks(out, state, mergeBlockId);
} }
// Restore current block id after emitting loop blocks. // Restore current block id after emitting loop blocks.
...@@ -2189,16 +2198,16 @@ namespace sw ...@@ -2189,16 +2198,16 @@ namespace sw
loopActiveLaneMask = SIMD::Int(0); loopActiveLaneMask = SIMD::Int(0);
for (auto in : block.ins) for (auto in : block.ins)
{ {
if (existsPath(blockId, in, block.mergeBlock)) if (existsPath(blockId, in, mergeBlockId))
{ {
loopActiveLaneMask |= GetActiveLaneMaskEdge(state, in, blockId); loopActiveLaneMask |= GetActiveLaneMaskEdge(state, in, blockId);
} }
} }
// Add active lanes to the merge lane mask. // Add active lanes to the merge lane mask.
for (auto in : getBlock(block.mergeBlock).ins) for (auto in : getBlock(mergeBlockId).ins)
{ {
auto edge = Block::Edge{in, block.mergeBlock}; auto edge = Block::Edge{in, mergeBlockId};
auto it = state->edgeActiveLaneMasks.find(edge); auto it = state->edgeActiveLaneMasks.find(edge);
if (it != state->edgeActiveLaneMasks.end()) if (it != state->edgeActiveLaneMasks.end())
{ {
...@@ -2211,7 +2220,39 @@ namespace sw ...@@ -2211,7 +2220,39 @@ namespace sw
{ {
if (insn.opcode() == spv::OpPhi) if (insn.opcode() == spv::OpPhi)
{ {
StorePhi(insn, state, loopBlocks); StorePhi(blockId, insn, state, loopBlocks);
}
}
// Use the [loop -> merge] active lane masks to update the phi values in
// the merge block. We need to do this to handle divergent control flow
// in the loop.
//
// Consider the following:
//
// int phi_source = 0;
// for (uint i = 0; i < 4; i++)
// {
// phi_source = 0;
// if (gl_GlobalInvocationID.x % 4 == i) // divergent control flow
// {
// phi_source = 42; // single lane assignment.
// break; // activeLaneMask for [loop->merge] is active for a single lane.
// }
// // -- we are here --
// }
// // merge block
// int phi = phi_source; // OpPhi
//
// In this example, with each iteration of the loop, phi_source will
// only have a single lane assigned. However by 'phi' value in the merge
// block needs to be assigned the union of all the per-lane assignments
// of phi_source when that lane exited the loop.
for (auto insn = mergeBlock.begin(); insn != mergeBlock.end(); insn++)
{
if (insn.opcode() == spv::OpPhi)
{
StorePhi(mergeBlockId, insn, state, mergeBlock.ins);
} }
} }
...@@ -2222,10 +2263,10 @@ namespace sw ...@@ -2222,10 +2263,10 @@ namespace sw
// Continue emitting from the merge block. // Continue emitting from the merge block.
Nucleus::setInsertBlock(mergeBasicBlock); Nucleus::setInsertBlock(mergeBasicBlock);
state->pending->emplace(block.mergeBlock); state->pending->emplace(mergeBlockId);
for (auto it : mergeActiveLaneMasks) for (auto it : mergeActiveLaneMasks)
{ {
state->addActiveLaneMaskEdge(it.first, block.mergeBlock, it.second); state->addActiveLaneMaskEdge(it.first, mergeBlockId, it.second);
} }
} }
...@@ -4618,7 +4659,13 @@ namespace sw ...@@ -4618,7 +4659,13 @@ namespace sw
SpirvShader::EmitResult SpirvShader::EmitPhi(InsnIterator insn, EmitState *state) const SpirvShader::EmitResult SpirvShader::EmitPhi(InsnIterator insn, EmitState *state) const
{ {
auto currentBlock = getBlock(state->currentBlock); auto currentBlock = getBlock(state->currentBlock);
StorePhi(insn, state, currentBlock.ins); if (!currentBlock.isLoopMerge)
{
// If this is a loop merge block, then don't attempt to update the
// phi values from the ins. EmitLoop() has had to take special care
// of this phi in order to correctly deal with divergent lanes.
StorePhi(state->currentBlock, insn, state, currentBlock.ins);
}
LoadPhi(insn, state); LoadPhi(insn, state);
return EmitResult::Continue; return EmitResult::Continue;
} }
...@@ -4641,13 +4688,12 @@ namespace sw ...@@ -4641,13 +4688,12 @@ namespace sw
} }
} }
void SpirvShader::StorePhi(InsnIterator insn, EmitState *state, std::unordered_set<SpirvShader::Block::ID> const& filter) const void SpirvShader::StorePhi(Block::ID currentBlock, InsnIterator insn, EmitState *state, std::unordered_set<SpirvShader::Block::ID> const& filter) const
{ {
auto routine = state->routine; auto routine = state->routine;
auto typeId = Type::ID(insn.word(1)); auto typeId = Type::ID(insn.word(1));
auto type = getType(typeId); auto type = getType(typeId);
auto objectId = Object::ID(insn.word(2)); auto objectId = Object::ID(insn.word(2));
auto currentBlock = getBlock(state->currentBlock);
auto storageIt = state->routine->phis.find(objectId); auto storageIt = state->routine->phis.find(objectId);
ASSERT(storageIt != state->routine->phis.end()); ASSERT(storageIt != state->routine->phis.end());
...@@ -4663,7 +4709,7 @@ namespace sw ...@@ -4663,7 +4709,7 @@ namespace sw
continue; continue;
} }
auto mask = GetActiveLaneMaskEdge(state, blockId, state->currentBlock); auto mask = GetActiveLaneMaskEdge(state, blockId, currentBlock);
auto in = GenericValue(this, routine, varId); auto in = GenericValue(this, routine, varId);
for (uint32_t i = 0; i < type.sizeInComponents; i++) for (uint32_t i = 0; i < type.sizeInComponents; i++)
......
...@@ -449,14 +449,14 @@ namespace sw ...@@ -449,14 +449,14 @@ namespace sw
Loop, // OpLoopMerge + [OpBranchConditional | OpBranch] Loop, // OpLoopMerge + [OpBranchConditional | OpBranch]
}; };
Kind kind; Kind kind = Simple;
InsnIterator mergeInstruction; // Structured control flow merge instruction. InsnIterator mergeInstruction; // Structured control flow merge instruction.
InsnIterator branchInstruction; // Branch instruction. InsnIterator branchInstruction; // Branch instruction.
ID mergeBlock; // Structured flow merge block. ID mergeBlock; // Structured flow merge block.
ID continueTarget; // Loop continue block. ID continueTarget; // Loop continue block.
Set ins; // Blocks that branch into this block. Set ins; // Blocks that branch into this block.
Set outs; // Blocks that this block branches to. Set outs; // Blocks that this block branches to.
bool isLoopMerge = false;
private: private:
InsnIterator begin_; InsnIterator begin_;
InsnIterator end_; InsnIterator end_;
...@@ -743,8 +743,12 @@ namespace sw ...@@ -743,8 +743,12 @@ namespace sw
// reachable. // reachable.
void TraverseReachableBlocks(Block::ID id, Block::Set& reachable); void TraverseReachableBlocks(Block::ID id, Block::Set& reachable);
// Assigns Block::ins from Block::outs for every block. // AssignBlockFields() performs the following for all reachable blocks:
void AssignBlockIns(); // * Assigns Block::ins with the identifiers of all blocks that contain
// this block in their Block::outs.
// * Sets Block::isLoopMerge to true if the block is the merge of a
// another loop block.
void AssignBlockFields();
// DeclareType creates a Type for the given OpTypeX instruction, storing // DeclareType creates a Type for the given OpTypeX instruction, storing
// it into the types map. It is called from the analysis pass (constructor). // it into the types map. It is called from the analysis pass (constructor).
...@@ -974,7 +978,7 @@ namespace sw ...@@ -974,7 +978,7 @@ namespace sw
// StorePhi updates the phi's alloca storage value using the incoming // StorePhi updates the phi's alloca storage value using the incoming
// values from blocks that are both in the OpPhi instruction and in // values from blocks that are both in the OpPhi instruction and in
// filter. // filter.
void StorePhi(InsnIterator insn, EmitState *state, std::unordered_set<SpirvShader::Block::ID> const& filter) const; void StorePhi(Block::ID blockID, InsnIterator insn, EmitState *state, std::unordered_set<SpirvShader::Block::ID> const& filter) const;
// Emits a rr::Fence for the given MemorySemanticsMask. // Emits a rr::Fence for the given MemorySemanticsMask.
void Fence(spv::MemorySemanticsMask semantics) const; void Fence(spv::MemorySemanticsMask semantics) const;
......
...@@ -149,7 +149,7 @@ std::vector<uint32_t> compileSpirv(const char* assembly) ...@@ -149,7 +149,7 @@ std::vector<uint32_t> compileSpirv(const char* assembly)
printf("%zu: '%s' != '%s'\n", line, srcLine.c_str(), disLine.c_str()); printf("%zu: '%s' != '%s'\n", line, srcLine.c_str(), disLine.c_str());
} }
} }
printf("\n\n---\n"); printf("\n\n---\nExpected:\n\n%s", disassembled.c_str());
} }
return spirv; return spirv;
...@@ -390,6 +390,20 @@ INSTANTIATE_TEST_CASE_P(ComputeParams, SwiftShaderVulkanBufferToBufferComputeTes ...@@ -390,6 +390,20 @@ INSTANTIATE_TEST_CASE_P(ComputeParams, SwiftShaderVulkanBufferToBufferComputeTes
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, Memcpy) TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, Memcpy)
{ {
std::stringstream src; std::stringstream src;
// #version 450
// layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
// layout(binding = 0, std430) buffer InBuffer
// {
// int Data[];
// } In;
// layout(binding = 1, std430) buffer OutBuffer
// {
// int Data[];
// } Out;
// void main()
// {
// Out.Data[gl_GlobalInvocationID.x] = In.Data[gl_GlobalInvocationID.x];
// }
src << src <<
"OpCapability Shader\n" "OpCapability Shader\n"
"OpMemoryModel Logical GLSL450\n" "OpMemoryModel Logical GLSL450\n"
...@@ -1384,3 +1398,103 @@ TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchPhi) ...@@ -1384,3 +1398,103 @@ TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchPhi)
test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 1 : 2; }); test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 1 : 2; });
} }
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, LoopDivergentMergePhi)
{
// #version 450
// layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
// layout(binding = 0, std430) buffer InBuffer
// {
// int Data[];
// } In;
// layout(binding = 1, std430) buffer OutBuffer
// {
// int Data[];
// } Out;
// void main()
// {
// int phi = 0;
// uint lane = gl_GlobalInvocationID.x % 4;
// for (uint i = 0; i < 4; i++)
// {
// if (lane == i)
// {
// phi = In.Data[gl_GlobalInvocationID.x];
// break;
// }
// }
// Out.Data[gl_GlobalInvocationID.x] = phi;
// }
std::stringstream src;
src <<
"OpCapability Shader\n"
"%1 = OpExtInstImport \"GLSL.std.450\"\n"
"OpMemoryModel Logical GLSL450\n"
"OpEntryPoint GLCompute %2 \"main\" %3\n"
"OpExecutionMode %2 LocalSize " <<
GetParam().localSizeX << " " <<
GetParam().localSizeY << " " <<
GetParam().localSizeZ << "\n" <<
"OpDecorate %3 BuiltIn GlobalInvocationId\n"
"OpDecorate %4 ArrayStride 4\n"
"OpMemberDecorate %5 0 Offset 0\n"
"OpDecorate %5 BufferBlock\n"
"OpDecorate %6 DescriptorSet 0\n"
"OpDecorate %6 Binding 0\n"
"OpDecorate %7 ArrayStride 4\n"
"OpMemberDecorate %8 0 Offset 0\n"
"OpDecorate %8 BufferBlock\n"
"OpDecorate %9 DescriptorSet 0\n"
"OpDecorate %9 Binding 1\n"
"%10 = OpTypeVoid\n"
"%11 = OpTypeFunction %10\n"
"%12 = OpTypeInt 32 1\n"
"%13 = OpConstant %12 0\n"
"%14 = OpTypeInt 32 0\n"
"%15 = OpTypeVector %14 3\n"
"%16 = OpTypePointer Input %15\n"
"%3 = OpVariable %16 Input\n"
"%17 = OpConstant %14 0\n"
"%18 = OpTypePointer Input %14\n"
"%19 = OpConstant %14 4\n"
"%20 = OpTypeBool\n"
"%4 = OpTypeRuntimeArray %12\n"
"%5 = OpTypeStruct %4\n"
"%21 = OpTypePointer Uniform %5\n"
"%6 = OpVariable %21 Uniform\n"
"%22 = OpTypePointer Uniform %12\n"
"%23 = OpConstant %12 1\n"
"%7 = OpTypeRuntimeArray %12\n"
"%8 = OpTypeStruct %7\n"
"%24 = OpTypePointer Uniform %8\n"
"%9 = OpVariable %24 Uniform\n"
"%2 = OpFunction %10 None %11\n"
"%25 = OpLabel\n"
"%26 = OpAccessChain %18 %3 %17\n"
"%27 = OpLoad %14 %26\n"
"%28 = OpUMod %14 %27 %19\n"
"OpBranch %29\n"
"%29 = OpLabel\n"
"%30 = OpPhi %14 %17 %25 %31 %32\n"
"%33 = OpULessThan %20 %30 %19\n"
"OpLoopMerge %34 %32 None\n"
"OpBranchConditional %33 %35 %34\n"
"%35 = OpLabel\n"
"%36 = OpIEqual %20 %28 %30\n"
"OpSelectionMerge %32 None\n"
"OpBranchConditional %36 %37 %32\n"
"%37 = OpLabel\n"
"%38 = OpAccessChain %22 %6 %13 %27\n"
"%39 = OpLoad %12 %38\n"
"OpBranch %34\n"
"%32 = OpLabel\n"
"%31 = OpIAdd %14 %30 %23\n"
"OpBranch %29\n"
"%34 = OpLabel\n"
"%40 = OpPhi %12 %13 %29 %39 %37\n" // %39: phi
"%41 = OpAccessChain %22 %9 %13 %27\n"
"OpStore %41 %40\n"
"OpReturn\n"
"OpFunctionEnd\n";
test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment