Commit be873929 by Shahbaz Youssefi Committed by Angle LUCI CQ

Vulkan: SPIR-V Gen: Handle constants and constructors

This change translates constants and constructors (minus type casting). With this change, shaders such as gl_Position = vec4(aposition, 0, 1); are translated correctly. Bug: angleproject:4889 Change-Id: I4463717cf880c6d05db179b98691d5cabc1a2d7c Reviewed-on: https://chromium-review.googlesource.com/c/angle/angle/+/2920192 Commit-Queue: Shahbaz Youssefi <syoussefi@chromium.org> Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org> Reviewed-by: 's avatarTim Van Patten <timvp@google.com>
parent c39e2a18
...@@ -51,7 +51,7 @@ spirv::IdRef SPIRVBuilder::getNewId() ...@@ -51,7 +51,7 @@ spirv::IdRef SPIRVBuilder::getNewId()
return newId; return newId;
} }
const SpirvTypeData &SPIRVBuilder::getTypeData(const TType &type, TLayoutBlockStorage blockStorage) SpirvType SPIRVBuilder::getSpirvType(const TType &type, TLayoutBlockStorage blockStorage) const
{ {
SpirvType spirvType; SpirvType spirvType;
spirvType.type = type.getBasicType(); spirvType.type = type.getBasicType();
...@@ -68,16 +68,13 @@ const SpirvTypeData &SPIRVBuilder::getTypeData(const TType &type, TLayoutBlockSt ...@@ -68,16 +68,13 @@ const SpirvTypeData &SPIRVBuilder::getTypeData(const TType &type, TLayoutBlockSt
spirvType.matrixPacking = EmpColumnMajor; spirvType.matrixPacking = EmpColumnMajor;
} }
const char *blockName = "";
if (type.getStruct() != nullptr) if (type.getStruct() != nullptr)
{ {
spirvType.block = type.getStruct(); spirvType.block = type.getStruct();
blockName = type.getStruct()->name().data();
} }
else if (type.isInterfaceBlock()) else if (type.isInterfaceBlock())
{ {
spirvType.block = type.getInterfaceBlock(); spirvType.block = type.getInterfaceBlock();
blockName = type.getInterfaceBlock()->name().data();
// Calculate the block storage from the interface block automatically. The fields inherit // Calculate the block storage from the interface block automatically. The fields inherit
// from this. Default to std140. // from this. Default to std140.
...@@ -94,6 +91,23 @@ const SpirvTypeData &SPIRVBuilder::getTypeData(const TType &type, TLayoutBlockSt ...@@ -94,6 +91,23 @@ const SpirvTypeData &SPIRVBuilder::getTypeData(const TType &type, TLayoutBlockSt
spirvType.blockStorage = EbsUnspecified; spirvType.blockStorage = EbsUnspecified;
} }
return spirvType;
}
const SpirvTypeData &SPIRVBuilder::getTypeData(const TType &type, TLayoutBlockStorage blockStorage)
{
SpirvType spirvType = getSpirvType(type, blockStorage);
const char *blockName = "";
if (type.getStruct() != nullptr)
{
blockName = type.getStruct()->name().data();
}
else if (type.isInterfaceBlock())
{
blockName = type.getInterfaceBlock()->name().data();
}
return getSpirvTypeData(spirvType, blockName); return getSpirvTypeData(spirvType, blockName);
} }
......
...@@ -169,6 +169,7 @@ class SPIRVBuilder : angle::NonCopyable ...@@ -169,6 +169,7 @@ class SPIRVBuilder : angle::NonCopyable
{} {}
spirv::IdRef getNewId(); spirv::IdRef getNewId();
SpirvType getSpirvType(const TType &type, TLayoutBlockStorage blockStorage) const;
const SpirvTypeData &getTypeData(const TType &type, TLayoutBlockStorage blockStorage); const SpirvTypeData &getTypeData(const TType &type, TLayoutBlockStorage blockStorage);
const SpirvTypeData &getSpirvTypeData(const SpirvType &type, const char *blockName); const SpirvTypeData &getSpirvTypeData(const SpirvType &type, const char *blockName);
spirv::IdRef getTypePointerId(spirv::IdRef typeId, spv::StorageClass storageClass); spirv::IdRef getTypePointerId(spirv::IdRef typeId, spv::StorageClass storageClass);
......
...@@ -121,6 +121,7 @@ class OutputSPIRVTraverser : public TIntermTraverser ...@@ -121,6 +121,7 @@ class OutputSPIRVTraverser : public TIntermTraverser
{ {
public: public:
OutputSPIRVTraverser(TCompiler *compiler, ShCompileOptions compileOptions); OutputSPIRVTraverser(TCompiler *compiler, ShCompileOptions compileOptions);
~OutputSPIRVTraverser() override;
spirv::Blob getSpirv(); spirv::Blob getSpirv();
...@@ -173,6 +174,35 @@ class OutputSPIRVTraverser : public TIntermTraverser ...@@ -173,6 +174,35 @@ class OutputSPIRVTraverser : public TIntermTraverser
TLayoutBlockStorage blockStorage) const; TLayoutBlockStorage blockStorage) const;
void nodeDataInitRValue(NodeData *data, spirv::IdRef baseId, spirv::IdRef typeId) const; void nodeDataInitRValue(NodeData *data, spirv::IdRef baseId, spirv::IdRef typeId) const;
spirv::IdRef createConstant(const TType &type,
TBasicType expectedBasicType,
const TConstantUnion *constUnion);
spirv::IdRef createConstructor(TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters);
spirv::IdRef createArrayOrStructConstructor(TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters);
spirv::IdRef createConstructorVectorFromScalar(const TType &type,
spirv::IdRef typeId,
const spirv::IdRefList &parameters);
spirv::IdRef createConstructorVectorFromNonScalar(TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters);
spirv::IdRef createConstructorMatrixFromScalar(TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters);
spirv::IdRef createConstructorMatrixFromVectors(TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters);
spirv::IdRef createConstructorMatrixFromMatrix(TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters);
void extractComponents(TIntermAggregate *node,
size_t componentCount,
const spirv::IdRefList &parameters,
spirv::IdRefList *extractedComponentsOut);
ANGLE_MAYBE_UNUSED TCompiler *mCompiler; ANGLE_MAYBE_UNUSED TCompiler *mCompiler;
ANGLE_MAYBE_UNUSED ShCompileOptions mCompileOptions; ANGLE_MAYBE_UNUSED ShCompileOptions mCompileOptions;
...@@ -245,6 +275,11 @@ OutputSPIRVTraverser::OutputSPIRVTraverser(TCompiler *compiler, ShCompileOptions ...@@ -245,6 +275,11 @@ OutputSPIRVTraverser::OutputSPIRVTraverser(TCompiler *compiler, ShCompileOptions
compiler->getNameMap()) compiler->getNameMap())
{} {}
OutputSPIRVTraverser::~OutputSPIRVTraverser()
{
ASSERT(mNodeData.empty());
}
void OutputSPIRVTraverser::nodeDataInitLValue(NodeData *data, void OutputSPIRVTraverser::nodeDataInitLValue(NodeData *data,
spirv::IdRef baseId, spirv::IdRef baseId,
spirv::IdRef typeId, spirv::IdRef typeId,
...@@ -627,8 +662,506 @@ spirv::IdRef OutputSPIRVTraverser::getAccessChainTypeId(NodeData *data) ...@@ -627,8 +662,506 @@ spirv::IdRef OutputSPIRVTraverser::getAccessChainTypeId(NodeData *data)
return accessChain.preSwizzleTypeId; return accessChain.preSwizzleTypeId;
} }
spirv::IdRef OutputSPIRVTraverser::createConstant(const TType &type,
TBasicType expectedBasicType,
const TConstantUnion *constUnion)
{
const spirv::IdRef typeId = mBuilder.getTypeData(type, EbsUnspecified).id;
spirv::IdRefList componentIds;
if (type.getBasicType() == EbtStruct)
{
// If it's a struct constant, get the constant id for each field.
for (const TField *field : type.getStruct()->fields())
{
const TType *fieldType = field->type();
componentIds.push_back(
createConstant(*fieldType, fieldType->getBasicType(), constUnion));
constUnion += fieldType->getObjectSize();
}
}
else
{
// Otherwise get the constant id for each component.
const size_t size = type.getObjectSize();
ASSERT(expectedBasicType == EbtFloat || expectedBasicType == EbtInt ||
expectedBasicType == EbtUInt || expectedBasicType == EbtBool);
for (size_t component = 0; component < size; ++component, ++constUnion)
{
spirv::IdRef componentId;
// If the constant has a different type than expected, cast it right away.
TConstantUnion castConstant;
bool valid = castConstant.cast(expectedBasicType, *constUnion);
ASSERT(valid);
switch (castConstant.getType())
{
case EbtFloat:
componentId = mBuilder.getFloatConstant(castConstant.getFConst());
break;
case EbtInt:
componentId = mBuilder.getIntConstant(castConstant.getIConst());
break;
case EbtUInt:
componentId = mBuilder.getUintConstant(castConstant.getUConst());
break;
case EbtBool:
componentId = mBuilder.getBoolConstant(castConstant.getBConst());
break;
default:
UNREACHABLE();
}
componentIds.push_back(componentId);
}
}
// If this is a composite, create a composite constant from the components.
if (type.getBasicType() == EbtStruct || componentIds.size() > 1)
{
return mBuilder.getCompositeConstant(typeId, componentIds);
}
// Otherwise return the sole component.
ASSERT(componentIds.size() == 1);
return componentIds[0];
}
spirv::IdRef OutputSPIRVTraverser::createConstructor(TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters)
{
const TType &type = node->getType();
const TIntermSequence &arguments = *node->getSequence();
const TType &arg0Type = arguments[0]->getAsTyped()->getType();
// Constructors in GLSL can take various shapes, resulting in different translations to SPIR-V
// (in each case, if the parameter doesn't match the type being constructed, it must be cast):
//
// - float(f): This should translate to just f
// - vecN(f): This should translate to OpCompositeConstruct %vecN %f %f .. %f
// - vecN(v1.zy, v2.x): This can technically translate to OpCompositeConstruct with two ids; the
// results of v1.zy and v2.x. However, for simplicity it's easier to generate that
// instruction with three ids; the results of v1.z, v1.y and v2.x (see below where a matrix is
// used as parameter).
// - vecN(m): This takes N components from m in column-major order (for example, vec4
// constructed out of a 4x3 matrix would select components (0,0), (0,1), (0,2) and (1,0)).
// This translates to OpCompositeConstruct with the id of the individual components extracted
// from m.
// - matNxM(f): This creates a diagonal matrix. It generates N OpCompositeConstruct
// instructions for each column (which are vecM), followed by an OpCompositeConstruct that
// constructs the final result.
// - matNxM(m):
// * With m larger than NxM, this extracts a submatrix out of m. It generates
// OpCompositeExtracts for N columns of m, followed by an OpVectorShuffle (swizzle) if the
// rows of m are more than M. OpCompositeConstruct is used to construct the final result.
// * If m is not larger than NxM, an identity matrix is created and superimposed with m.
// OpCompositeExtract is used to extract each component of m (that is necessary), and
// together with the zero or one constants necessary used to create the columns (with
// OpCompositeConstruct). OpCompositeConstruct is used to construct the final result.
// - matNxM(v1.zy, v2.x, ...): Similarly to constructing a vector, a list of single components
// are extracted from the parameters, which are divided up and used to construct each column,
// which is finally constructed into the final result.
//
// Additionally, array and structs are constructed by OpCompositeConstruct followed by ids of
// each parameter which must enumerate every individual element / field.
if (type.isArray() || type.getStruct() != nullptr)
{
return createArrayOrStructConstructor(node, typeId, parameters);
}
if (type.isScalar())
{
// TODO: handle casting. http://anglebug.com/4889.
return parameters[0];
}
if (type.isVector())
{
if (arguments.size() == 1 && arg0Type.isScalar())
{
return createConstructorVectorFromScalar(node->getType(), typeId, parameters);
}
return createConstructorVectorFromNonScalar(node, typeId, parameters);
}
ASSERT(type.isMatrix());
if (arg0Type.isScalar())
{
return createConstructorMatrixFromScalar(node, typeId, parameters);
}
if (arg0Type.isMatrix())
{
return createConstructorMatrixFromMatrix(node, typeId, parameters);
}
return createConstructorMatrixFromVectors(node, typeId, parameters);
}
spirv::IdRef OutputSPIRVTraverser::createArrayOrStructConstructor(
TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters)
{
const spirv::IdRef result = mBuilder.getNewId();
spirv::WriteCompositeConstruct(mBuilder.getSpirvFunctions(), typeId, result, parameters);
return result;
}
spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromScalar(
const TType &type,
spirv::IdRef typeId,
const spirv::IdRefList &parameters)
{
// vecN(f) translates to OpCompositeConstruct %vecN %f ... %f
ASSERT(parameters.size() == 1);
spirv::IdRefList replicatedParameter(type.getNominalSize(), parameters[0]);
const spirv::IdRef result = mBuilder.getNewId();
spirv::WriteCompositeConstruct(mBuilder.getSpirvFunctions(), typeId, result,
replicatedParameter);
return result;
}
spirv::IdRef OutputSPIRVTraverser::createConstructorVectorFromNonScalar(
TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters)
{
// vecN(v1.zy, v2.x) translates to OpCompositeConstruct %vecN %v1.z %v1.y %v2.x
// vecN(m) translates to OpCompositeConstruct %vecN %m[0][0] %m[0][1] ...
spirv::IdRefList extractedComponents;
extractComponents(node, node->getType().getNominalSize(), parameters, &extractedComponents);
const spirv::IdRef result = mBuilder.getNewId();
spirv::WriteCompositeConstruct(mBuilder.getSpirvFunctions(), typeId, result,
extractedComponents);
return result;
}
spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromScalar(
TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters)
{
// matNxM(f) translates to
//
// %c0 = OpCompositeConstruct %vecM %f %zero %zero ..
// %c1 = OpCompositeConstruct %vecM %zero %f %zero ..
// %c2 = OpCompositeConstruct %vecM %zero %zero %f ..
// ...
// %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
const TType &type = node->getType();
// TODO: handle casting. http://anglebug.com/4889.
const spirv::IdRef scalarId = parameters[0];
spirv::IdRef zeroId;
switch (type.getBasicType())
{
case EbtFloat:
zeroId = mBuilder.getFloatConstant(0);
break;
case EbtInt:
zeroId = mBuilder.getIntConstant(0);
break;
case EbtUInt:
zeroId = mBuilder.getUintConstant(0);
break;
case EbtBool:
zeroId = mBuilder.getBoolConstant(0);
break;
default:
UNREACHABLE();
}
spirv::IdRefList componentIds(type.getRows(), zeroId);
spirv::IdRefList columnIds;
SpirvType columnType = mBuilder.getSpirvType(type, EbsUnspecified);
columnType.secondarySize = 1;
const spirv::IdRef columnTypeId = mBuilder.getSpirvTypeData(columnType, "").id;
for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
{
columnIds.push_back(mBuilder.getNewId());
// Place the scalar at the correct index (diagonal of the matrix, i.e. row == col).
componentIds[columnIndex] = scalarId;
if (columnIndex > 0)
{
componentIds[columnIndex - 1] = zeroId;
}
// Create the column.
spirv::WriteCompositeConstruct(mBuilder.getSpirvFunctions(), columnTypeId, columnIds.back(),
componentIds);
}
// Create the matrix out of the columns.
const spirv::IdRef result = mBuilder.getNewId();
spirv::WriteCompositeConstruct(mBuilder.getSpirvFunctions(), typeId, result, columnIds);
return result;
}
spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromVectors(
TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters)
{
// matNxM(v1.zy, v2.x, ...) translates to:
//
// %c0 = OpCompositeConstruct %vecM %v1.z %v1.y %v2.x ..
// ...
// %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
const TType &type = node->getType();
spirv::IdRefList extractedComponents;
extractComponents(node, type.getCols() * type.getRows(), parameters, &extractedComponents);
spirv::IdRefList columnIds;
SpirvType columnType = mBuilder.getSpirvType(type, EbsUnspecified);
columnType.secondarySize = 1;
const spirv::IdRef columnTypeId = mBuilder.getSpirvTypeData(columnType, "").id;
// Chunk up the extracted components by column and construct intermediary vectors.
for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
{
columnIds.push_back(mBuilder.getNewId());
auto componentsStart = extractedComponents.begin() + columnIndex * type.getRows();
const spirv::IdRefList componentIds(componentsStart, componentsStart + type.getRows());
// Create the column.
spirv::WriteCompositeConstruct(mBuilder.getSpirvFunctions(), columnTypeId, columnIds.back(),
componentIds);
}
const spirv::IdRef result = mBuilder.getNewId();
spirv::WriteCompositeConstruct(mBuilder.getSpirvFunctions(), typeId, result, columnIds);
return result;
}
spirv::IdRef OutputSPIRVTraverser::createConstructorMatrixFromMatrix(
TIntermAggregate *node,
spirv::IdRef typeId,
const spirv::IdRefList &parameters)
{
// matNxM(m) translates to:
//
// - If m is SxR where S>=N and R>=M:
//
// %c0 = OpCompositeExtract %vecR %m 0
// %c1 = OpCompositeExtract %vecR %m 1
// ...
// // If R (column size of m) != M, OpVectorShuffle to extract M components out of %ci.
// ...
// %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 ...
//
// - Otherwise, an identity matrix is created and super imposed by m:
//
// %c0 = OpCompositeConstruct %vecM %m[0][0] %m[0][1] %0 %0
// %c1 = OpCompositeConstruct %vecM %m[1][0] %m[1][1] %0 %0
// %c2 = OpCompositeConstruct %vecM %m[2][0] %m[2][1] %1 %0
// %c3 = OpCompositeConstruct %vecM %0 %0 %0 %1
// %m = OpCompositeConstruct %matNxM %c0 %c1 %c2 %c3
const TType &type = node->getType();
const TType &parameterType = (*node->getSequence())[0]->getAsTyped()->getType();
// TODO: handle casting. http://anglebug.com/4889.
ASSERT(parameters.size() == 1);
spirv::IdRefList columnIds;
SpirvType columnType = mBuilder.getSpirvType(type, EbsUnspecified);
columnType.secondarySize = 1;
const spirv::IdRef columnTypeId = mBuilder.getSpirvTypeData(columnType, "").id;
if (parameterType.getCols() >= type.getCols() && parameterType.getRows() >= type.getRows())
{
// If the parameter is a larger matrix than the constructor type, extract the columns
// directly and potentially swizzle them.
SpirvType paramColumnType = mBuilder.getSpirvType(parameterType, EbsUnspecified);
paramColumnType.secondarySize = 1;
const spirv::IdRef paramColumnTypeId = mBuilder.getSpirvTypeData(paramColumnType, "").id;
const bool needsSwizzle = parameterType.getRows() > type.getRows();
spirv::LiteralIntegerList swizzle = {spirv::LiteralInteger(0), spirv::LiteralInteger(1),
spirv::LiteralInteger(2), spirv::LiteralInteger(3)};
swizzle.resize(type.getRows());
for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
{
// Extract the column.
const spirv::IdRef parameterColumnId = mBuilder.getNewId();
spirv::WriteCompositeExtract(mBuilder.getSpirvFunctions(), paramColumnTypeId,
parameterColumnId, parameters[0],
{spirv::LiteralInteger(columnIndex)});
// If the column has too many components, select the appropriate number of components.
spirv::IdRef constructorColumnId = parameterColumnId;
if (needsSwizzle)
{
constructorColumnId = mBuilder.getNewId();
spirv::WriteVectorShuffle(mBuilder.getSpirvFunctions(), columnTypeId,
constructorColumnId, parameterColumnId, parameterColumnId,
swizzle);
}
columnIds.push_back(constructorColumnId);
}
}
else
{
// Otherwise create an identity matrix and fill in the components that can be taken from the
// given parameter.
SpirvType paramComponentType = mBuilder.getSpirvType(parameterType, EbsUnspecified);
paramComponentType.primarySize = 1;
paramComponentType.secondarySize = 1;
const spirv::IdRef paramComponentTypeId =
mBuilder.getSpirvTypeData(paramComponentType, "").id;
for (int columnIndex = 0; columnIndex < type.getCols(); ++columnIndex)
{
spirv::IdRefList componentIds;
for (int componentIndex = 0; componentIndex < type.getRows(); ++componentIndex)
{
// Take the component from the constructor parameter if possible.
spirv::IdRef componentId;
if (componentIndex < parameterType.getRows())
{
componentId = mBuilder.getNewId();
spirv::WriteCompositeExtract(mBuilder.getSpirvFunctions(), paramComponentTypeId,
componentId, parameters[0],
{spirv::LiteralInteger(columnIndex),
spirv::LiteralInteger(componentIndex)});
}
else
{
const bool isOnDiagonal = columnIndex == componentIndex;
switch (type.getBasicType())
{
case EbtFloat:
componentId = mBuilder.getFloatConstant(isOnDiagonal ? 0.0f : 1.0f);
break;
case EbtInt:
componentId = mBuilder.getIntConstant(isOnDiagonal ? 0 : 1);
break;
case EbtUInt:
componentId = mBuilder.getUintConstant(isOnDiagonal ? 0 : 1);
break;
case EbtBool:
componentId = mBuilder.getBoolConstant(isOnDiagonal);
break;
default:
UNREACHABLE();
}
}
componentIds.push_back(componentId);
}
// Create the column vector.
columnIds.push_back(mBuilder.getNewId());
spirv::WriteCompositeConstruct(mBuilder.getSpirvFunctions(), columnTypeId,
columnIds.back(), componentIds);
}
}
const spirv::IdRef result = mBuilder.getNewId();
spirv::WriteCompositeConstruct(mBuilder.getSpirvFunctions(), typeId, result, columnIds);
return result;
}
void OutputSPIRVTraverser::extractComponents(TIntermAggregate *node,
size_t componentCount,
const spirv::IdRefList &parameters,
spirv::IdRefList *extractedComponentsOut)
{
// A helper function that takes the list of parameters passed to a constructor (which may have
// more components than necessary) and extracts the first componentCount components.
const TIntermSequence &arguments = *node->getSequence();
// TODO: handle casting. http://anglebug.com/4889.
ASSERT(arguments.size() == parameters.size());
for (size_t argumentIndex = 0;
argumentIndex < arguments.size() && extractedComponentsOut->size() < componentCount;
++argumentIndex)
{
const TType &argumentType = arguments[argumentIndex]->getAsTyped()->getType();
const spirv::IdRef parameterId = parameters[argumentIndex];
if (argumentType.isScalar())
{
// For scalar parameters, there's nothing to do.
extractedComponentsOut->push_back(parameterId);
continue;
}
if (argumentType.isVector())
{
SpirvType componentType = mBuilder.getSpirvType(argumentType, EbsUnspecified);
componentType.primarySize = 1;
const spirv::IdRef componentTypeId = mBuilder.getSpirvTypeData(componentType, "").id;
// For vector parameters, take components out of the vector one by one.
for (int componentIndex = 0; componentIndex < argumentType.getNominalSize() &&
extractedComponentsOut->size() < componentCount;
++componentIndex)
{
const spirv::IdRef componentId = mBuilder.getNewId();
spirv::WriteCompositeExtract(mBuilder.getSpirvFunctions(), componentTypeId,
componentId, parameterId,
{spirv::LiteralInteger(componentIndex)});
extractedComponentsOut->push_back(componentId);
}
continue;
}
ASSERT(argumentType.isMatrix());
SpirvType componentType = mBuilder.getSpirvType(argumentType, EbsUnspecified);
componentType.primarySize = 1;
componentType.secondarySize = 1;
const spirv::IdRef componentTypeId = mBuilder.getSpirvTypeData(componentType, "").id;
// For matrix parameters, take components out of the matrix one by one in column-major
// order.
for (int columnIndex = 0; columnIndex < argumentType.getCols() &&
extractedComponentsOut->size() < componentCount;
++columnIndex)
{
for (int componentIndex = 0; componentIndex < argumentType.getRows() &&
extractedComponentsOut->size() < componentCount;
++componentIndex)
{
const spirv::IdRef componentId = mBuilder.getNewId();
spirv::WriteCompositeExtract(
mBuilder.getSpirvFunctions(), componentTypeId, componentId, parameterId,
{spirv::LiteralInteger(columnIndex), spirv::LiteralInteger(componentIndex)});
extractedComponentsOut->push_back(componentId);
}
}
}
}
void OutputSPIRVTraverser::visitSymbol(TIntermSymbol *node) void OutputSPIRVTraverser::visitSymbol(TIntermSymbol *node)
{ {
// Constants are expected to be folded.
ASSERT(!node->hasConstantValue());
mNodeData.emplace_back(); mNodeData.emplace_back();
// The symbol is either: // The symbol is either:
...@@ -672,12 +1205,58 @@ void OutputSPIRVTraverser::visitSymbol(TIntermSymbol *node) ...@@ -672,12 +1205,58 @@ void OutputSPIRVTraverser::visitSymbol(TIntermSymbol *node)
void OutputSPIRVTraverser::visitConstantUnion(TIntermConstantUnion *node) void OutputSPIRVTraverser::visitConstantUnion(TIntermConstantUnion *node)
{ {
// TODO: http://anglebug.com/4889 mNodeData.emplace_back();
UNIMPLEMENTED();
const TType &type = node->getType();
// Find out the expected type for this constant, so it can be cast right away and not need an
// instruction to do that.
TIntermNode *parent = getParentNode();
const size_t childIndex = getParentChildIndex();
TBasicType expectedBasicType = type.getBasicType();
if (parent->getAsAggregate())
{
TIntermAggregate *parentAggregate = parent->getAsAggregate();
// There are three possibilities:
//
// - It's a struct constructor: The basic type must match that of the corresponding field of
// the struct.
// - It's a non struct constructor: The basic type must match that of the the type being
// constructed.
// - It's a function call: The basic type must match that of the corresponding argument.
if (parentAggregate->isConstructor())
{
const TStructure *structure = parentAggregate->getType().getStruct();
if (structure != nullptr)
{
expectedBasicType = structure->fields()[childIndex]->type()->getBasicType();
}
else
{
expectedBasicType = parentAggregate->getType().getBasicType();
}
}
else
{
expectedBasicType =
parentAggregate->getFunction()->getParam(childIndex)->getType().getBasicType();
}
}
// TODO: other node types such as binary, ternary etc. http://anglebug.com/4889
const spirv::IdRef typeId = mBuilder.getTypeData(type, EbsUnspecified).id;
const spirv::IdRef constId = createConstant(type, expectedBasicType, node->getConstantValue());
nodeDataInitRValue(&mNodeData.back(), constId, typeId);
} }
bool OutputSPIRVTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node) bool OutputSPIRVTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node)
{ {
// Constants are expected to be folded.
ASSERT(!node->hasConstantValue());
if (visit == PreVisit) if (visit == PreVisit)
{ {
// Don't add an entry to the stack. The child will create one, which we won't pop. // Don't add an entry to the stack. The child will create one, which we won't pop.
...@@ -714,6 +1293,9 @@ bool OutputSPIRVTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node) ...@@ -714,6 +1293,9 @@ bool OutputSPIRVTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node)
bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node) bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node)
{ {
// Constants are expected to be folded.
ASSERT(!node->hasConstantValue());
if (visit == PreVisit) if (visit == PreVisit)
{ {
// Don't add an entry to the stack. The left child will create one, which we won't pop. // Don't add an entry to the stack. The left child will create one, which we won't pop.
...@@ -726,7 +1308,7 @@ bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -726,7 +1308,7 @@ bool OutputSPIRVTraverser::visitBinary(Visit visit, TIntermBinary *node)
ASSERT(mNodeData.size() >= 1); ASSERT(mNodeData.size() >= 1);
// As an optimization, if the index is EOpIndexDirect*, take the constant index directly and // As an optimization, if the index is EOpIndexDirect*, take the constant index directly and
// add it to the access chain as constant. // add it to the access chain as literal.
switch (node->getOp()) switch (node->getOp())
{ {
case EOpIndexDirect: case EOpIndexDirect:
...@@ -859,6 +1441,13 @@ bool OutputSPIRVTraverser::visitBlock(Visit visit, TIntermBlock *node) ...@@ -859,6 +1441,13 @@ bool OutputSPIRVTraverser::visitBlock(Visit visit, TIntermBlock *node)
mNodeData.pop_back(); mNodeData.pop_back();
} }
if (visit != PostVisit)
{
return true;
}
mNodeData.pop_back();
return true; return true;
} }
...@@ -938,6 +1527,50 @@ void OutputSPIRVTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node ...@@ -938,6 +1527,50 @@ void OutputSPIRVTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node
bool OutputSPIRVTraverser::visitAggregate(Visit visit, TIntermAggregate *node) bool OutputSPIRVTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
{ {
// Constants are expected to be folded.
ASSERT(!node->hasConstantValue());
if (visit == PreVisit)
{
mNodeData.emplace_back();
return true;
}
// Keep the parameters on the stack. If a function call contains out or inout parameters, we
// need to know the access chains for the eventual write back to them.
if (visit == InVisit)
{
return true;
}
// Expect to have accumulated as many parameters as the node requires.
size_t parameterCount = node->getChildCount();
ASSERT(mNodeData.size() > parameterCount);
const spirv::IdRef typeId = mBuilder.getTypeData(node->getType(), EbsUnspecified).id;
if (node->isConstructor())
{
// Construct a value out of the accumulated parameters.
spirv::IdRefList parameters;
for (size_t paramIndex = 0; paramIndex < parameterCount; ++paramIndex)
{
// Take each constructor argument that is visited and evaluate it as rvalue
NodeData &param = mNodeData[mNodeData.size() - parameterCount + paramIndex];
const spirv::IdRef paramValue = accessChainLoad(&param);
// TODO: handle mismatching types. http://anglebug.com/6000
parameters.push_back(paramValue);
}
mNodeData.resize(mNodeData.size() - parameterCount);
const spirv::IdRef result = createConstructor(node, typeId, parameters);
nodeDataInitRValue(&mNodeData.back(), result, typeId);
return true;
}
// TODO: http://anglebug.com/4889 // TODO: http://anglebug.com/4889
UNIMPLEMENTED(); UNIMPLEMENTED();
......
...@@ -27,6 +27,8 @@ void TIntermTraverser::traverse(T *node) ...@@ -27,6 +27,8 @@ void TIntermTraverser::traverse(T *node)
bool visit = true; bool visit = true;
mCurrentChildIndex = 0;
// Visit the node before children if pre-visiting. // Visit the node before children if pre-visiting.
if (preVisit) if (preVisit)
visit = node->visit(PreVisit, this); visit = node->visit(PreVisit, this);
...@@ -38,6 +40,7 @@ void TIntermTraverser::traverse(T *node) ...@@ -38,6 +40,7 @@ void TIntermTraverser::traverse(T *node)
while (childIndex < childCount && visit) while (childIndex < childCount && visit)
{ {
mCurrentChildIndex = childIndex;
node->getChildNode(childIndex)->traverse(this); node->getChildNode(childIndex)->traverse(this);
if (inVisit && childIndex != childCount - 1) if (inVisit && childIndex != childCount - 1)
{ {
...@@ -217,7 +220,8 @@ TIntermTraverser::TIntermTraverser(bool preVisit, ...@@ -217,7 +220,8 @@ TIntermTraverser::TIntermTraverser(bool preVisit,
mMaxDepth(0), mMaxDepth(0),
mMaxAllowedDepth(std::numeric_limits<int>::max()), mMaxAllowedDepth(std::numeric_limits<int>::max()),
mInGlobalScope(true), mInGlobalScope(true),
mSymbolTable(symbolTable) mSymbolTable(symbolTable),
mCurrentChildIndex(0)
{ {
// Only enabling inVisit is not supported. // Only enabling inVisit is not supported.
ASSERT(!(inVisit && !preVisit && !postVisit)); ASSERT(!(inVisit && !preVisit && !postVisit));
......
...@@ -149,6 +149,11 @@ class TIntermTraverser : angle::NonCopyable ...@@ -149,6 +149,11 @@ class TIntermTraverser : angle::NonCopyable
return nullptr; return nullptr;
} }
// Returns what child index is currently being visited. For example when visiting the children
// of an aggregate, it can be used to find out which argument of the parent (aggregate) node
// they correspond to.
size_t getParentChildIndex() const { return mCurrentChildIndex; }
const TIntermBlock *getParentBlock() const; const TIntermBlock *getParentBlock() const;
TIntermNode *getRootNode() const TIntermNode *getRootNode() const
...@@ -287,6 +292,8 @@ class TIntermTraverser : angle::NonCopyable ...@@ -287,6 +292,8 @@ class TIntermTraverser : angle::NonCopyable
// All the nodes from root to the current node during traversing. // All the nodes from root to the current node during traversing.
TVector<TIntermNode *> mPath; TVector<TIntermNode *> mPath;
// The current child of parent being traversed.
size_t mCurrentChildIndex;
// All the code blocks from the root to the current node's parent during traversal. // All the code blocks from the root to the current node's parent during traversal.
std::vector<ParentBlock> mParentBlockStack; std::vector<ParentBlock> mParentBlockStack;
......
...@@ -1053,7 +1053,7 @@ TEST_P(BufferStorageTestES3, StorageBufferMapBufferOES) ...@@ -1053,7 +1053,7 @@ TEST_P(BufferStorageTestES3, StorageBufferMapBufferOES)
ANGLE_INSTANTIATE_TEST_ES2(BufferDataTest); ANGLE_INSTANTIATE_TEST_ES2(BufferDataTest);
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(BufferDataTestES3); GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(BufferDataTestES3);
ANGLE_INSTANTIATE_TEST_ES3(BufferDataTestES3); ANGLE_INSTANTIATE_TEST_ES3_AND(BufferDataTestES3, WithDirectSPIRVGeneration(ES3_VULKAN()));
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(BufferStorageTestES3); GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(BufferStorageTestES3);
ANGLE_INSTANTIATE_TEST_ES3(BufferStorageTestES3); ANGLE_INSTANTIATE_TEST_ES3(BufferStorageTestES3);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment