Commit e48f2210 by Shahbaz Youssefi Committed by Angle LUCI CQ

Vulkan: SPIR-V Gen: Support short-circuiting && and ||

When short-circuiting is necessary (because the right-hand side has side effects), the following code is generated in SPIR-V: // For left && right: result = left if (left) result = right // For left || right: result = left if (!left) result = right Bug: angleproject:4889 Change-Id: Id87b56dc4a1463ed781852a23d2ba6eb2015d700 Reviewed-on: https://chromium-review.googlesource.com/c/angle/angle/+/2953366 Commit-Queue: Shahbaz Youssefi <syoussefi@chromium.org> Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org> Reviewed-by: 's avatarTim Van Patten <timvp@google.com>
parent a76f224f
...@@ -1126,6 +1126,39 @@ void SPIRVBuilder::writeInterfaceVariableDecorations(const TType &type, spirv::I ...@@ -1126,6 +1126,39 @@ void SPIRVBuilder::writeInterfaceVariableDecorations(const TType &type, spirv::I
} }
} }
void SPIRVBuilder::writeBranchConditional(spirv::IdRef conditionValue,
spirv::IdRef trueBlock,
spirv::IdRef falseBlock,
spirv::IdRef mergeBlock)
{
// Generate the following:
//
// OpSelectionMerge %mergeBlock None
// OpBranchConditional %conditionValue %trueBlock %falseBlock
//
spirv::WriteSelectionMerge(getSpirvCurrentFunctionBlock(), mergeBlock,
spv::SelectionControlMaskNone);
spirv::WriteBranchConditional(getSpirvCurrentFunctionBlock(), conditionValue, trueBlock,
falseBlock, {});
terminateCurrentFunctionBlock();
// Start the true or false block, whichever exists.
nextConditionalBlock();
}
void SPIRVBuilder::writeBranchConditionalBlockEnd()
{
// Insert a branch to the merge block at the end of each if-else block.
const spirv::IdRef mergeBlock = getCurrentConditional()->blockIds.back();
ASSERT(!isCurrentFunctionBlockTerminated());
spirv::WriteBranch(getSpirvCurrentFunctionBlock(), mergeBlock);
terminateCurrentFunctionBlock();
// Move on to the next block.
nextConditionalBlock();
}
uint32_t SPIRVBuilder::calculateBaseAlignmentAndSize(const SpirvType &type, uint32_t SPIRVBuilder::calculateBaseAlignmentAndSize(const SpirvType &type,
uint32_t *sizeInStorageBlockOut, uint32_t *sizeInStorageBlockOut,
uint32_t *matrixStrideOut) uint32_t *matrixStrideOut)
...@@ -1411,6 +1444,8 @@ ImmutableString SPIRVBuilder::hashFunctionName(const TFunction *func) ...@@ -1411,6 +1444,8 @@ ImmutableString SPIRVBuilder::hashFunctionName(const TFunction *func)
spirv::Blob SPIRVBuilder::getSpirv() spirv::Blob SPIRVBuilder::getSpirv()
{ {
ASSERT(mConditionalStack.empty());
spirv::Blob result; spirv::Blob result;
// Reserve a minimum amount of memory. // Reserve a minimum amount of memory.
......
...@@ -279,6 +279,12 @@ class SPIRVBuilder : angle::NonCopyable ...@@ -279,6 +279,12 @@ class SPIRVBuilder : angle::NonCopyable
!mSpirvCurrentFunctionBlocks.back().isTerminated); !mSpirvCurrentFunctionBlocks.back().isTerminated);
return &mSpirvCurrentFunctionBlocks.back().body; return &mSpirvCurrentFunctionBlocks.back().body;
} }
spirv::IdRef getSpirvCurrentFunctionBlockId()
{
ASSERT(!mSpirvCurrentFunctionBlocks.empty() &&
!mSpirvCurrentFunctionBlocks.back().isTerminated);
return mSpirvCurrentFunctionBlocks.back().labelId;
}
bool isCurrentFunctionBlockTerminated() const bool isCurrentFunctionBlockTerminated() const
{ {
ASSERT(!mSpirvCurrentFunctionBlocks.empty()); ASSERT(!mSpirvCurrentFunctionBlocks.empty());
...@@ -298,6 +304,11 @@ class SPIRVBuilder : angle::NonCopyable ...@@ -298,6 +304,11 @@ class SPIRVBuilder : angle::NonCopyable
void addEntryPointInterfaceVariableId(spirv::IdRef id); void addEntryPointInterfaceVariableId(spirv::IdRef id);
void writePerVertexBuiltIns(const TType &type, spirv::IdRef typeId); void writePerVertexBuiltIns(const TType &type, spirv::IdRef typeId);
void writeInterfaceVariableDecorations(const TType &type, spirv::IdRef variableId); void writeInterfaceVariableDecorations(const TType &type, spirv::IdRef variableId);
void writeBranchConditional(spirv::IdRef conditionValue,
spirv::IdRef trueBlock,
spirv::IdRef falseBlock,
spirv::IdRef mergeBlock);
void writeBranchConditionalBlockEnd();
spirv::IdRef getBoolConstant(bool value); spirv::IdRef getBoolConstant(bool value);
spirv::IdRef getUintConstant(uint32_t value); spirv::IdRef getUintConstant(uint32_t value);
......
...@@ -229,6 +229,9 @@ class OutputSPIRVTraverser : public TIntermTraverser ...@@ -229,6 +229,9 @@ class OutputSPIRVTraverser : public TIntermTraverser
const spirv::IdRefList &parameters, const spirv::IdRefList &parameters,
spirv::IdRefList *extractedComponentsOut); spirv::IdRefList *extractedComponentsOut);
void startShortCircuit(TIntermBinary *node);
spirv::IdRef endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId);
spirv::IdRef createFunctionCall(TIntermAggregate *node, spirv::IdRef resultTypeId); spirv::IdRef createFunctionCall(TIntermAggregate *node, spirv::IdRef resultTypeId);
spirv::IdRef createAtomicBuiltIn(TIntermAggregate *node, spirv::IdRef resultTypeId); spirv::IdRef createAtomicBuiltIn(TIntermAggregate *node, spirv::IdRef resultTypeId);
...@@ -739,7 +742,7 @@ void OutputSPIRVTraverser::accessChainStore(NodeData *data, spirv::IdRef value) ...@@ -739,7 +742,7 @@ void OutputSPIRVTraverser::accessChainStore(NodeData *data, spirv::IdRef value)
// written. Use the final result as the value to be written to the vector. // written. Use the final result as the value to be written to the vector.
const spirv::IdRef result = mBuilder.getNewId({}); const spirv::IdRef result = mBuilder.getNewId({});
spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(), spirv::WriteVectorShuffle(mBuilder.getSpirvCurrentFunctionBlock(),
accessChain.postSwizzleTypeId, result, loadResult, value, accessChain.preSwizzleTypeId, result, loadResult, value,
swizzleList); swizzleList);
value = result; value = result;
} }
...@@ -1350,6 +1353,77 @@ void OutputSPIRVTraverser::extractComponents(TIntermAggregate *node, ...@@ -1350,6 +1353,77 @@ void OutputSPIRVTraverser::extractComponents(TIntermAggregate *node,
} }
} }
void OutputSPIRVTraverser::startShortCircuit(TIntermBinary *node)
{
// Emulate && and || as such:
//
// || => if (!left) result = right
// && => if ( left) result = right
//
// When this function is called, |left| has already been visited, so it creates the appropriate
// |if| construct in preparation for visiting |right|.
// Load |left| and replace the access chain with an rvalue that's the result.
spirv::IdRef typeId = getAccessChainTypeId(&mNodeData.back());
const spirv::IdRef left =
accessChainLoad(&mNodeData.back(), mBuilder.getDecorations(node->getLeft()->getType()));
nodeDataInitRValue(&mNodeData.back(), left, typeId);
// Keep the id of the block |left| was evaluated in.
mNodeData.back().idList.push_back(mBuilder.getSpirvCurrentFunctionBlockId());
// Two blocks necessary, one for the |if| block, and one for the merge block.
mBuilder.startConditional(2, false, false);
// Generate the branch instructions.
SpirvConditional *conditional = mBuilder.getCurrentConditional();
const spirv::IdRef mergeBlock = conditional->blockIds.back();
const spirv::IdRef ifBlock = conditional->blockIds.front();
spirv::IdRef trueBlock = node->getOp() == EOpLogicalAnd ? ifBlock : mergeBlock;
spirv::IdRef falseBlock = node->getOp() == EOpLogicalOr ? ifBlock : mergeBlock;
// Note that no logical not is necessary. For ||, the branch will target the merge block in the
// true case.
mBuilder.writeBranchConditional(left, trueBlock, falseBlock, mergeBlock);
}
spirv::IdRef OutputSPIRVTraverser::endShortCircuit(TIntermBinary *node, spirv::IdRef *typeId)
{
// Load the right hand side.
const spirv::IdRef right =
accessChainLoad(&mNodeData.back(), mBuilder.getDecorations(node->getRight()->getType()));
mNodeData.pop_back();
// Get the id of the block |right| is evaluated in.
const spirv::IdRef rightBlockId = mBuilder.getSpirvCurrentFunctionBlockId();
// And the cached id of the block |left| is evaluated in.
ASSERT(mNodeData.back().idList.size() == 1);
const spirv::IdRef leftBlockId = mNodeData.back().idList[0].id;
mNodeData.back().idList.clear();
// Move on to the merge block.
mBuilder.writeBranchConditionalBlockEnd();
// Pop from the conditional stack.
mBuilder.endConditional();
// Get the previously loaded result of the left hand side.
*typeId = getAccessChainTypeId(&mNodeData.back());
const spirv::IdRef left = mNodeData.back().baseId;
// Create an OpPhi instruction that selects either the |left| or |right| based on which block
// was traversed.
const spirv::IdRef result = mBuilder.getNewId(mBuilder.getDecorations(node->getType()));
spirv::WritePhi(
mBuilder.getSpirvCurrentFunctionBlock(), *typeId, result,
{spirv::PairIdRefIdRef{left, leftBlockId}, spirv::PairIdRefIdRef{right, rightBlockId}});
return result;
}
spirv::IdRef OutputSPIRVTraverser::createFunctionCall(TIntermAggregate *node, spirv::IdRef OutputSPIRVTraverser::createFunctionCall(TIntermAggregate *node,
spirv::IdRef resultTypeId) spirv::IdRef resultTypeId)
{ {
...@@ -1718,6 +1792,24 @@ bool OutputSPIRVTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node) ...@@ -1718,6 +1792,24 @@ bool OutputSPIRVTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node)
return true; return true;
} }
bool IsShortCircuitNeeded(TIntermBinary *node)
{
TOperator op = node->getOp();
// Short circuit is only necessary for && and ||.
if (op != EOpLogicalAnd && op != EOpLogicalOr)
{
return false;
}
// If the right hand side does not have side effects, short-circuiting is unnecessary.
// TODO: experiment with the performance of OpLogicalAnd/Or vs short-circuit based on the
// complexity of the right hand side expression. We could potentially only allow
// OpLogicalAnd/Or if the right hand side is a constant or an access chain and have more complex
// expressions be placed inside an if block. http://anglebug.com/4889
return node->getRight()->hasSideEffects();
}
bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node) bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node)
{ {
// Constants are expected to be folded. // Constants are expected to be folded.
...@@ -1736,6 +1828,27 @@ bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -1736,6 +1828,27 @@ bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node)
return true; return true;
} }
if (IsShortCircuitNeeded(node))
{
// For && and ||, if short-circuiting behavior is needed, we need to emulate it with an
// |if| construct. At this point, the left-hand side is already evaluated, so we need to
// create an appropriate conditional on in-visit and visit the right-hand-side inside the
// conditional block. On post-visit, OpPhi is used to calculate the result.
if (visit == InVisit)
{
startShortCircuit(node);
return true;
}
spirv::IdRef typeId;
const spirv::IdRef result = endShortCircuit(node, &typeId);
// Replace the access chain with an rvalue that's the result.
nodeDataInitRValue(&mNodeData.back(), result, typeId);
return true;
}
if (visit == InVisit) if (visit == InVisit)
{ {
// Left child visited. Take the entry it created as the current node's. // Left child visited. Take the entry it created as the current node's.
...@@ -1939,6 +2052,21 @@ bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -1939,6 +2052,21 @@ bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node)
writeBinaryOp = spirv::WriteSGreaterThanEqual; writeBinaryOp = spirv::WriteSGreaterThanEqual;
break; break;
case EOpLogicalOr:
ASSERT(!IsShortCircuitNeeded(node));
extendScalarToVector = false;
writeBinaryOp = spirv::WriteLogicalOr;
break;
case EOpLogicalXor:
extendScalarToVector = false;
writeBinaryOp = spirv::WriteLogicalNotEqual;
break;
case EOpLogicalAnd:
ASSERT(!IsShortCircuitNeeded(node));
extendScalarToVector = false;
writeBinaryOp = spirv::WriteLogicalAnd;
break;
case EOpBitShiftLeft: case EOpBitShiftLeft:
case EOpBitShiftLeftAssign: case EOpBitShiftLeftAssign:
writeBinaryOp = spirv::WriteShiftLeftLogical; writeBinaryOp = spirv::WriteShiftLeftLogical;
...@@ -2431,9 +2559,9 @@ bool OutputSPIRVTraverser::visitIfElse(Visit visit, TIntermIfElse *node) ...@@ -2431,9 +2559,9 @@ bool OutputSPIRVTraverser::visitIfElse(Visit visit, TIntermIfElse *node)
// Generate the branch instructions. // Generate the branch instructions.
SpirvConditional *conditional = mBuilder.getCurrentConditional(); SpirvConditional *conditional = mBuilder.getCurrentConditional();
spirv::IdRef mergeBlock = conditional->blockIds.back(); const spirv::IdRef mergeBlock = conditional->blockIds.back();
spirv::IdRef trueBlock = mergeBlock; spirv::IdRef trueBlock = mergeBlock;
spirv::IdRef falseBlock = mergeBlock; spirv::IdRef falseBlock = mergeBlock;
size_t nextBlockIndex = 0; size_t nextBlockIndex = 0;
if (node->getTrueBlock()) if (node->getTrueBlock())
...@@ -2445,32 +2573,13 @@ bool OutputSPIRVTraverser::visitIfElse(Visit visit, TIntermIfElse *node) ...@@ -2445,32 +2573,13 @@ bool OutputSPIRVTraverser::visitIfElse(Visit visit, TIntermIfElse *node)
falseBlock = conditional->blockIds[nextBlockIndex++]; falseBlock = conditional->blockIds[nextBlockIndex++];
} }
// Generate the following: mBuilder.writeBranchConditional(conditionValue, trueBlock, falseBlock, mergeBlock);
//
// OpSelectionMerge %mergeBlock None
// OpBranchConditional %conditionValue %trueBlock %falseBlock
//
spirv::WriteSelectionMerge(mBuilder.getSpirvCurrentFunctionBlock(), mergeBlock,
spv::SelectionControlMaskNone);
spirv::WriteBranchConditional(mBuilder.getSpirvCurrentFunctionBlock(), conditionValue,
trueBlock, falseBlock, {});
mBuilder.terminateCurrentFunctionBlock();
// Start the true or false block, whichever exists.
mBuilder.nextConditionalBlock();
return true; return true;
} }
// Otherwise move on to the next block, inserting a branch to the merge block at the end of each // Otherwise move on to the next block, inserting a branch to the merge block at the end of each
// block. // block.
spirv::IdRef mergeBlock = mBuilder.getCurrentConditional()->blockIds.back(); mBuilder.writeBranchConditionalBlockEnd();
ASSERT(!mBuilder.isCurrentFunctionBlockTerminated());
spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(), mergeBlock);
mBuilder.terminateCurrentFunctionBlock();
mBuilder.nextConditionalBlock();
// Pop from the conditional stack when done. // Pop from the conditional stack when done.
if (visit == PostVisit) if (visit == PostVisit)
......
...@@ -235,11 +235,8 @@ ANGLE_NO_DISCARD bool RotateAndFlipBuiltinVariable(TCompiler *compiler, ...@@ -235,11 +235,8 @@ ANGLE_NO_DISCARD bool RotateAndFlipBuiltinVariable(TCompiler *compiler,
TIntermBinary *plusPivot = new TIntermBinary(EOpAdd, inverseXY, pivot->deepCopy()); TIntermBinary *plusPivot = new TIntermBinary(EOpAdd, inverseXY, pivot->deepCopy());
// Create the corrected variable and copy the value of the original builtin. // Create the corrected variable and copy the value of the original builtin.
TIntermSequence sequence; TIntermBinary *assignment =
sequence.push_back(builtinRef->deepCopy()); new TIntermBinary(EOpAssign, flippedBuiltinRef, builtinRef->deepCopy());
TIntermAggregate *aggregate =
TIntermAggregate::CreateConstructor(builtin->getType(), &sequence);
TIntermBinary *assignment = new TIntermBinary(EOpAssign, flippedBuiltinRef, aggregate);
// Create an assignment to the replaced variable's .xy. // Create an assignment to the replaced variable's .xy.
TIntermSwizzle *correctedXY = TIntermSwizzle *correctedXY =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment