Commit 5fc501ff by John Kessenich Committed by GitHub

Merge pull request #520 from amdrexu/bugfix

SPV: OpGroupBroadcast is unable to handle vector operand.
parents f38978ed b707205b
...@@ -156,9 +156,7 @@ protected: ...@@ -156,9 +156,7 @@ protected:
spv::Id makeSmearedConstant(spv::Id constant, int vectorSize); spv::Id makeSmearedConstant(spv::Id constant, int vectorSize);
spv::Id createAtomicOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy); spv::Id createAtomicOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy);
spv::Id createInvocationsOperation(glslang::TOperator op, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy); spv::Id createInvocationsOperation(glslang::TOperator op, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy);
#ifdef AMD_EXTENSIONS spv::Id CreateInvocationsVectorOperation(spv::Op op, spv::Id typeId, std::vector<spv::Id>& operands);
spv::Id CreateInvocationsVectorOperation(spv::Op op, spv::Id typeId, spv::Id operand);
#endif
spv::Id createMiscOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy); spv::Id createMiscOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy);
spv::Id createNoArgOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId); spv::Id createNoArgOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId);
spv::Id getSymbolId(const glslang::TIntermSymbol* node); spv::Id getSymbolId(const glslang::TIntermSymbol* node);
...@@ -4029,6 +4027,8 @@ spv::Id TGlslangToSpvTraverser::createInvocationsOperation(glslang::TOperator op ...@@ -4029,6 +4027,8 @@ spv::Id TGlslangToSpvTraverser::createInvocationsOperation(glslang::TOperator op
case glslang::EOpReadInvocation: case glslang::EOpReadInvocation:
opCode = spv::OpGroupBroadcast; opCode = spv::OpGroupBroadcast;
if (builder.isVectorType(typeId))
return CreateInvocationsVectorOperation(opCode, typeId, operands);
break; break;
case glslang::EOpReadFirstInvocation: case glslang::EOpReadFirstInvocation:
opCode = spv::OpSubgroupFirstInvocationKHR; opCode = spv::OpSubgroupFirstInvocationKHR;
...@@ -4084,7 +4084,7 @@ spv::Id TGlslangToSpvTraverser::createInvocationsOperation(glslang::TOperator op ...@@ -4084,7 +4084,7 @@ spv::Id TGlslangToSpvTraverser::createInvocationsOperation(glslang::TOperator op
} }
if (builder.isVectorType(typeId)) if (builder.isVectorType(typeId))
return CreateInvocationsVectorOperation(opCode, typeId, operands[0]); return CreateInvocationsVectorOperation(opCode, typeId, operands);
break; break;
case glslang::EOpMinInvocationsNonUniform: case glslang::EOpMinInvocationsNonUniform:
...@@ -4118,7 +4118,7 @@ spv::Id TGlslangToSpvTraverser::createInvocationsOperation(glslang::TOperator op ...@@ -4118,7 +4118,7 @@ spv::Id TGlslangToSpvTraverser::createInvocationsOperation(glslang::TOperator op
} }
if (builder.isVectorType(typeId)) if (builder.isVectorType(typeId))
return CreateInvocationsVectorOperation(opCode, typeId, operands[0]); return CreateInvocationsVectorOperation(opCode, typeId, operands);
break; break;
#endif #endif
...@@ -4131,16 +4131,21 @@ spv::Id TGlslangToSpvTraverser::createInvocationsOperation(glslang::TOperator op ...@@ -4131,16 +4131,21 @@ spv::Id TGlslangToSpvTraverser::createInvocationsOperation(glslang::TOperator op
return builder.createOp(opCode, typeId, spvGroupOperands); return builder.createOp(opCode, typeId, spvGroupOperands);
} }
#ifdef AMD_EXTENSIONS
// Create group invocation operations on a vector // Create group invocation operations on a vector
spv::Id TGlslangToSpvTraverser::CreateInvocationsVectorOperation(spv::Op op, spv::Id typeId, spv::Id operand) spv::Id TGlslangToSpvTraverser::CreateInvocationsVectorOperation(spv::Op op, spv::Id typeId, std::vector<spv::Id>& operands)
{ {
#ifdef AMD_EXTENSIONS
assert(op == spv::OpGroupFMin || op == spv::OpGroupUMin || op == spv::OpGroupSMin || assert(op == spv::OpGroupFMin || op == spv::OpGroupUMin || op == spv::OpGroupSMin ||
op == spv::OpGroupFMax || op == spv::OpGroupUMax || op == spv::OpGroupSMax || op == spv::OpGroupFMax || op == spv::OpGroupUMax || op == spv::OpGroupSMax ||
op == spv::OpGroupFAdd || op == spv::OpGroupIAdd || op == spv::OpGroupFAdd || op == spv::OpGroupIAdd || op == spv::OpGroupBroadcast ||
op == spv::OpGroupFMinNonUniformAMD || op == spv::OpGroupUMinNonUniformAMD || op == spv::OpGroupSMinNonUniformAMD || op == spv::OpGroupFMinNonUniformAMD || op == spv::OpGroupUMinNonUniformAMD || op == spv::OpGroupSMinNonUniformAMD ||
op == spv::OpGroupFMaxNonUniformAMD || op == spv::OpGroupUMaxNonUniformAMD || op == spv::OpGroupSMaxNonUniformAMD || op == spv::OpGroupFMaxNonUniformAMD || op == spv::OpGroupUMaxNonUniformAMD || op == spv::OpGroupSMaxNonUniformAMD ||
op == spv::OpGroupFAddNonUniformAMD || op == spv::OpGroupIAddNonUniformAMD); op == spv::OpGroupFAddNonUniformAMD || op == spv::OpGroupIAddNonUniformAMD);
#else
assert(op == spv::OpGroupFMin || op == spv::OpGroupUMin || op == spv::OpGroupSMin ||
op == spv::OpGroupFMax || op == spv::OpGroupUMax || op == spv::OpGroupSMax ||
op == spv::OpGroupFAdd || op == spv::OpGroupIAdd || op == spv::OpGroupBroadcast);
#endif
// Handle group invocation operations scalar by scalar. // Handle group invocation operations scalar by scalar.
// The result type is the same type as the original type. // The result type is the same type as the original type.
...@@ -4150,28 +4155,32 @@ spv::Id TGlslangToSpvTraverser::CreateInvocationsVectorOperation(spv::Op op, spv ...@@ -4150,28 +4155,32 @@ spv::Id TGlslangToSpvTraverser::CreateInvocationsVectorOperation(spv::Op op, spv
// - make a vector out the scalar results // - make a vector out the scalar results
// get the types sorted out // get the types sorted out
int numComponents = builder.getNumComponents(operand); int numComponents = builder.getNumComponents(operands[0]);
spv::Id scalarType = builder.getScalarTypeId(builder.getTypeId(operand)); spv::Id scalarType = builder.getScalarTypeId(builder.getTypeId(operands[0]));
std::vector<spv::Id> results; std::vector<spv::Id> results;
// do each scalar op // do each scalar op
for (int comp = 0; comp < numComponents; ++comp) { for (int comp = 0; comp < numComponents; ++comp) {
std::vector<unsigned int> indexes; std::vector<unsigned int> indexes;
indexes.push_back(comp); indexes.push_back(comp);
spv::Id scalar = builder.createCompositeExtract(operand, scalarType, indexes); spv::Id scalar = builder.createCompositeExtract(operands[0], scalarType, indexes);
std::vector<spv::Id> operands; std::vector<spv::Id> spvGroupOperands;
operands.push_back(builder.makeUintConstant(spv::ScopeSubgroup)); spvGroupOperands.push_back(builder.makeUintConstant(spv::ScopeSubgroup));
operands.push_back(spv::GroupOperationReduce); if (op == spv::OpGroupBroadcast) {
operands.push_back(scalar); spvGroupOperands.push_back(scalar);
spvGroupOperands.push_back(operands[1]);
} else {
spvGroupOperands.push_back(spv::GroupOperationReduce);
spvGroupOperands.push_back(scalar);
}
results.push_back(builder.createOp(op, scalarType, operands)); results.push_back(builder.createOp(op, scalarType, spvGroupOperands));
} }
// put the pieces together // put the pieces together
return builder.createCompositeConstruct(typeId, results); return builder.createCompositeConstruct(typeId, results);
} }
#endif
spv::Id TGlslangToSpvTraverser::createMiscOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy) spv::Id TGlslangToSpvTraverser::createMiscOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment