Commit 77831a6d by Shahbaz Youssefi Committed by Angle LUCI CQ

Vulkan: SPIR-V Gen: If-else blocks

This change implements if-else blocks as well as a few simple binary operations (equality, less-than etc). It builds on prior work to generate the function in separate blocks and introduces a "conditionals" stack to support nesting of if-else blocks, switches and loops. Bug: angleproject:4889 Change-Id: If7694000487811837ed5946753568b41d67199f0 Reviewed-on: https://chromium-review.googlesource.com/c/angle/angle/+/2929660 Commit-Queue: Shahbaz Youssefi <syoussefi@chromium.org> Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org> Reviewed-by: 's avatarTim Van Patten <timvp@google.com>
parent 934fc56e
......@@ -843,16 +843,63 @@ spirv::IdRef SPIRVBuilder::declareVariable(spirv::IdRef typeId,
return variableId;
}
void SPIRVBuilder::startConditional(size_t blockCount, bool isContinuable, bool isBreakable)
{
mConditionalStack.emplace_back();
SpirvConditional &conditional = mConditionalStack.back();
// Create the requested number of block ids.
conditional.blockIds.resize(blockCount);
for (spirv::IdRef &blockId : conditional.blockIds)
{
blockId = getNewId();
}
conditional.isContinuable = isContinuable;
conditional.isBreakable = isBreakable;
// Don't automatically start the next block. The caller needs to generate instructions based on
// the ids that were just generated above.
}
void SPIRVBuilder::nextConditionalBlock()
{
ASSERT(!mConditionalStack.empty());
SpirvConditional &conditional = mConditionalStack.back();
ASSERT(conditional.nextBlockToWrite < conditional.blockIds.size());
spirv::IdRef blockId = conditional.blockIds[conditional.nextBlockToWrite++];
// The previous block must have properly terminated.
ASSERT(isCurrentFunctionBlockTerminated());
// Generate a new block.
mSpirvCurrentFunctionBlocks.emplace_back();
mSpirvCurrentFunctionBlocks.back().labelId = blockId;
}
void SPIRVBuilder::endConditional()
{
ASSERT(!mConditionalStack.empty());
// No blocks should be left.
ASSERT(mConditionalStack.back().nextBlockToWrite == mConditionalStack.back().blockIds.size());
mConditionalStack.pop_back();
}
uint32_t SPIRVBuilder::nextUnusedBinding()
{
return mNextUnusedBinding++;
}
uint32_t SPIRVBuilder::nextUnusedInputLocation(uint32_t consumedCount)
{
uint32_t nextUnused = mNextUnusedInputLocation;
mNextUnusedInputLocation += consumedCount;
return nextUnused;
}
uint32_t SPIRVBuilder::nextUnusedOutputLocation(uint32_t consumedCount)
{
uint32_t nextUnused = mNextUnusedOutputLocation;
......
......@@ -181,6 +181,32 @@ struct SpirvBlock
bool isTerminated = false;
};
// Conditional code, constituting ifs, switches and loops.
struct SpirvConditional
{
// The id of blocks that make up the conditional.
//
// - For if, there are three blocks: the then, else and merge blocks
// - For loops, there are four blocks: the condition, body, continue and merge blocks
// - For switch, there are a number of blocks based on the cases.
//
// In all cases, the merge block is the last block in this list. When the conditional is done
// with, that's the block that will be made "current" and future instructions written to. The
// merge block is also the branch target of "break" instructions.
//
// For loops, the continue target block is the one before last block in this list.
std::vector<spirv::IdRef> blockIds;
// Up to which block is already generated. Used by nextConditionalBlock() to generate a block
// and give it an id pre-determined in blockIds.
size_t nextBlockToWrite = 0;
// Used to determine if continue will affect this (i.e. it's a loop).
bool isContinuable = false;
// Used to determine if break will affect this (i.e. it's a loop or switch).
bool isBreakable = false;
};
// Helper class to construct SPIR-V
class SPIRVBuilder : angle::NonCopyable
{
......@@ -225,6 +251,7 @@ class SPIRVBuilder : angle::NonCopyable
ASSERT(!mSpirvCurrentFunctionBlocks.empty());
mSpirvCurrentFunctionBlocks.back().isTerminated = true;
}
SpirvConditional *getCurrentConditional() { return &mConditionalStack.back(); }
void addCapability(spv::Capability capability);
void addExecutionMode(spv::ExecutionMode executionMode);
......@@ -255,6 +282,11 @@ class SPIRVBuilder : angle::NonCopyable
spirv::IdRef *initializerId,
const char *name);
// Helpers for conditionals.
void startConditional(size_t blockCount, bool isContinuable, bool isBreakable);
void nextConditionalBlock();
void endConditional();
// TODO: remove name hashing once translation through glslang is removed. That is necessary to
// avoid name collision between ANGLE's internal symbols and user-defined ones when compiling
// the generated GLSL, but is irrelevant when generating SPIR-V directly. Currently, the SPIR-V
......@@ -347,6 +379,13 @@ class SPIRVBuilder : angle::NonCopyable
// List of function types that are already defined.
angle::HashMap<SpirvIdAndIdList, spirv::IdRef, SpirvIdAndIdListHash> mFunctionTypeIdMap;
// Stack of conditionals. When an if, loop or switch is visited, a new conditional scope is
// added. When the conditional construct is entirely visited, it's popped. As the blocks of
// the conditional constructs are visited, ids are consumed from the top of the stack. When
// break or continue is visited, the stack is traversed backwards until a loop or switch is
// found.
std::vector<SpirvConditional> mConditionalStack;
// name hashing.
ShHashFunction64 mHashFunction;
NameMap &mNameMap;
......
......@@ -1217,7 +1217,7 @@ void OutputSPIRVTraverser::visitConstantUnion(TIntermConstantUnion *node)
// Find out the expected type for this constant, so it can be cast right away and not need an
// instruction to do that.
TIntermNode *parent = getParentNode();
const size_t childIndex = getParentChildIndex();
const size_t childIndex = getParentChildIndex(PreVisit);
TBasicType expectedBasicType = type.getBasicType();
if (parent->getAsAggregate())
......@@ -1347,13 +1347,23 @@ bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node)
NodeData &left = mNodeData.back();
spirv::IdRef typeId;
const TBasicType leftBasicType = node->getLeft()->getType().getBasicType();
const bool isFloat = leftBasicType == EbtFloat || leftBasicType == EbtDouble;
const bool isUnsigned = leftBasicType == EbtUInt;
const bool isBool = leftBasicType == EbtBool;
using WriteBinaryOp =
void (*)(spirv::Blob * blob, spirv::IdResultType idResultType, spirv::IdResult idResult,
spirv::IdRef operand1, spirv::IdRef operand2);
WriteBinaryOp writeBinaryOp = nullptr;
switch (node->getOp())
{
case EOpIndexDirect:
case EOpIndexDirectStruct:
case EOpIndexDirectInterfaceBlock:
UNREACHABLE();
break;
return true;
case EOpIndexIndirect:
typeId = mBuilder.getTypeData(node->getType(), left.accessChain.baseBlockStorage).id;
if (!node->getLeft()->getType().isArray() && node->getLeft()->getType().isVector())
......@@ -1364,18 +1374,91 @@ bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node)
{
accessChainPush(&left, rightValue, typeId);
}
break;
return true;
case EOpAssign:
// Store into the access chain. Since the result of the (a = b) expression is b, change
// the access chain to an unindexed rvalue which is |rightValue|.
accessChainStore(&left, rightValue);
nodeDataInitRValue(&left, rightValue, rightTypeId);
break;
case EOpEqual:
case EOpEqualComponentWise:
if (isFloat)
writeBinaryOp = spirv::WriteFOrdEqual;
else if (isBool)
writeBinaryOp = spirv::WriteLogicalEqual;
else
writeBinaryOp = spirv::WriteIEqual;
break;
case EOpNotEqual:
case EOpNotEqualComponentWise:
if (isFloat)
writeBinaryOp = spirv::WriteFUnordNotEqual;
else if (isBool)
writeBinaryOp = spirv::WriteLogicalNotEqual;
else
writeBinaryOp = spirv::WriteINotEqual;
break;
case EOpLessThan:
case EOpLessThanComponentWise:
if (isFloat)
writeBinaryOp = spirv::WriteFOrdLessThan;
else if (isUnsigned)
writeBinaryOp = spirv::WriteULessThan;
else
writeBinaryOp = spirv::WriteSLessThan;
break;
case EOpGreaterThan:
case EOpGreaterThanComponentWise:
if (isFloat)
writeBinaryOp = spirv::WriteFOrdGreaterThan;
else if (isUnsigned)
writeBinaryOp = spirv::WriteUGreaterThan;
else
writeBinaryOp = spirv::WriteSGreaterThan;
break;
case EOpLessThanEqual:
case EOpLessThanEqualComponentWise:
if (isFloat)
writeBinaryOp = spirv::WriteFOrdLessThanEqual;
else if (isUnsigned)
writeBinaryOp = spirv::WriteULessThanEqual;
else
writeBinaryOp = spirv::WriteSLessThanEqual;
break;
case EOpGreaterThanEqual:
case EOpGreaterThanEqualComponentWise:
if (isFloat)
writeBinaryOp = spirv::WriteFOrdGreaterThanEqual;
else if (isUnsigned)
writeBinaryOp = spirv::WriteUGreaterThanEqual;
else
writeBinaryOp = spirv::WriteSGreaterThanEqual;
break;
default:
UNIMPLEMENTED();
break;
}
if (writeBinaryOp)
{
// Load the left value.
const spirv::IdRef leftValue = accessChainLoad(&left);
ASSERT(!typeId.valid());
typeId = mBuilder.getTypeData(node->getType(), EbsUnspecified).id;
// Write the operation that combines the left and right values.
spirv::IdRef result = mBuilder.getNewId();
writeBinaryOp(mBuilder.getSpirvCurrentFunctionBlock(), typeId, result, leftValue,
rightValue);
// Replace the access chain with an rvalue that's the result.
nodeDataInitRValue(&left, result, typeId);
}
return true;
}
......@@ -1397,8 +1480,73 @@ bool OutputSPIRVTraverser::visitTernary(Visit visit, TIntermTernary *node)
bool OutputSPIRVTraverser::visitIfElse(Visit visit, TIntermIfElse *node)
{
// TODO: http://anglebug.com/4889
UNIMPLEMENTED();
if (visit == PreVisit)
{
// Don't add an entry to the stack. The condition will create one, which we won't pop.
return true;
}
size_t lastChildIndex = getLastTraversedChildIndex(visit);
// If the condition was just visited, evaluate it and create the branch instructions.
if (lastChildIndex == 0)
{
const spirv::IdRef conditionValue = accessChainLoad(&mNodeData.back());
// Create a conditional with maximum 3 blocks, one for the true block (if any), one for the
// else block (if any), and one for the merge block. getChildCount() works here as it
// produces an identical count.
mBuilder.startConditional(node->getChildCount(), false, false);
// Generate the branch instructions.
SpirvConditional *conditional = mBuilder.getCurrentConditional();
spirv::IdRef mergeBlock = conditional->blockIds.back();
spirv::IdRef trueBlock = mergeBlock;
spirv::IdRef falseBlock = mergeBlock;
size_t nextBlockIndex = 0;
if (node->getTrueBlock())
{
trueBlock = conditional->blockIds[nextBlockIndex++];
}
if (node->getFalseBlock())
{
falseBlock = conditional->blockIds[nextBlockIndex++];
}
// Generate the following:
//
// OpSelectionMerge %mergeBlock None
// OpBranchConditional %conditionValue %trueBlock %falseBlock
//
spirv::WriteSelectionMerge(mBuilder.getSpirvCurrentFunctionBlock(), mergeBlock,
spv::SelectionControlMaskNone);
spirv::WriteBranchConditional(mBuilder.getSpirvCurrentFunctionBlock(), conditionValue,
trueBlock, falseBlock, {});
mBuilder.terminateCurrentFunctionBlock();
// Start the true or false block, whichever exists.
mBuilder.nextConditionalBlock();
return true;
}
// Otherwise move on to the next block, inserting a branch to the merge block at the end of each
// block.
spirv::IdRef mergeBlock = mBuilder.getCurrentConditional()->blockIds.back();
ASSERT(!mBuilder.isCurrentFunctionBlockTerminated());
spirv::WriteBranch(mBuilder.getSpirvCurrentFunctionBlock(), mergeBlock);
mBuilder.terminateCurrentFunctionBlock();
mBuilder.nextConditionalBlock();
// Pop from the conditional stack when done.
if (visit == PostVisit)
{
mBuilder.endConditional();
}
return true;
}
......
......@@ -42,6 +42,8 @@ void TIntermTraverser::traverse(T *node)
{
mCurrentChildIndex = childIndex;
node->getChildNode(childIndex)->traverse(this);
mCurrentChildIndex = childIndex;
if (inVisit && childIndex != childCount - 1)
{
visit = node->visit(InVisit, this);
......
......@@ -151,8 +151,19 @@ class TIntermTraverser : angle::NonCopyable
// Returns what child index is currently being visited. For example when visiting the children
// of an aggregate, it can be used to find out which argument of the parent (aggregate) node
// they correspond to.
size_t getParentChildIndex() const { return mCurrentChildIndex; }
// they correspond to. Only valid in the PreVisit call of the child.
size_t getParentChildIndex(Visit visit) const
{
ASSERT(visit == PreVisit);
return mCurrentChildIndex;
}
// Returns what child index has just been processed. Only valid in the InVisit and PostVisit
// calls of the parent node.
size_t getLastTraversedChildIndex(Visit visit) const
{
ASSERT(visit != PreVisit);
return mCurrentChildIndex;
}
const TIntermBlock *getParentBlock() const;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment