Commit 13545206 by qining

Spec Constant Operations

Approach: Add a flag in `Builder` to indicate 'spec constant mode' and 'normal mode'. When the builder is in 'normal mode', nothing changed. When the builder is in 'spec constant mode', binary, unary and other instruction creation rountines will be redirected to `createSpecConstantOp()` to create instrution at module level with `OpSpecConstantOp <original opcode> <operands>`. 'spec constant mode' should be enabled if and only if we are creating spec constants. So a flager setter/recover guard is added when handling binary/unary nodes in `createSpvConstantsFromConstSubTree()`. Note when handling spec constants which are represented as ConstantUnion Node, we should not use `OpSpecConstantOp` to initialize the composite constant, so builder is set to 'normal mode'. Tests: Tests are added in Test/spv.specConstantOperations.vert, including: 1) Arithmetic, shift opeations for both scalar and composite type spec constants. 2) Size conversion from/to float and double for both scalar and vector. 3) Bitwise and/or/xor for both scalar and vector. 4) Unary negate/not for both scalar and vector. 5) Vector swizzles. 6) Comparisons for scalars. 7) == and != for composite type spec constants Issues: 1) To implement == and != for composite type spec constants, the Spec needs to allow OpAll, OpAny, OpFOrdEqual, OpFUnordEqual, OpOrdNotEqual, OpFUnordNotEqual. Currently none of them are allowed in the Spec.
parent c3869fee
...@@ -130,7 +130,7 @@ protected: ...@@ -130,7 +130,7 @@ protected:
void addMemberDecoration(spv::Id id, int member, spv::Decoration dec, unsigned value); void addMemberDecoration(spv::Id id, int member, spv::Decoration dec, unsigned value);
spv::Id createSpvConstant(const glslang::TIntermTyped&); spv::Id createSpvConstant(const glslang::TIntermTyped&);
spv::Id createSpvConstantFromConstUnionArray(const glslang::TType& type, const glslang::TConstUnionArray&, int& nextConst, bool specConstant); spv::Id createSpvConstantFromConstUnionArray(const glslang::TType& type, const glslang::TConstUnionArray&, int& nextConst, bool specConstant);
spv::Id createSpvConstantFromConstSubTree(const glslang::TIntermTyped* subTree); spv::Id createSpvConstantFromConstSubTree(glslang::TIntermTyped* subTree);
bool isTrivialLeaf(const glslang::TIntermTyped* node); bool isTrivialLeaf(const glslang::TIntermTyped* node);
bool isTrivial(const glslang::TIntermTyped* node); bool isTrivial(const glslang::TIntermTyped* node);
spv::Id createShortCircuit(glslang::TOperator, glslang::TIntermTyped& left, glslang::TIntermTyped& right); spv::Id createShortCircuit(glslang::TOperator, glslang::TIntermTyped& left, glslang::TIntermTyped& right);
...@@ -3854,15 +3854,37 @@ spv::Id TGlslangToSpvTraverser::createSpvConstantFromConstUnionArray(const glsla ...@@ -3854,15 +3854,37 @@ spv::Id TGlslangToSpvTraverser::createSpvConstantFromConstUnionArray(const glsla
return builder.makeCompositeConstant(typeId, spvConsts); return builder.makeCompositeConstant(typeId, spvConsts);
} }
namespace {
class SpecConstOpCodeGenerationSettingGuard{
public:
SpecConstOpCodeGenerationSettingGuard(spv::Builder* builder,
bool shouldGeneratingForSpecConst)
: builder_(builder) {
previous_flag_ = builder->isInSpecConstCodeGenMode();
shouldGeneratingForSpecConst ? builder->setToSpecConstCodeGenMode()
: builder->setToNormalCodeGenMode();
}
~SpecConstOpCodeGenerationSettingGuard() {
previous_flag_ ? builder_->setToSpecConstCodeGenMode()
: builder_->setToNormalCodeGenMode();
}
private:
spv::Builder* builder_;
bool previous_flag_;
};
}
// Create constant ID from const initializer sub tree. // Create constant ID from const initializer sub tree.
spv::Id TGlslangToSpvTraverser::createSpvConstantFromConstSubTree( spv::Id TGlslangToSpvTraverser::createSpvConstantFromConstSubTree(
const glslang::TIntermTyped* subTree) { glslang::TIntermTyped* subTree) {
const glslang::TType& glslangType = subTree->getType(); const glslang::TType& glslangType = subTree->getType();
spv::Id typeId = convertGlslangToSpvType(glslangType); spv::Id typeId = convertGlslangToSpvType(glslangType);
bool is_spec_const = subTree->getType().getQualifier().isSpecConstant(); bool is_spec_const = subTree->getType().getQualifier().isSpecConstant();
if (const glslang::TIntermAggregate* an = subTree->getAsAggregate()) { if (const glslang::TIntermAggregate* an = subTree->getAsAggregate()) {
// Aggregate node, we should generate OpConstantComposite or // Aggregate node, we should generate OpConstantComposite or
// OpSpecConstantComposite instruction. // OpSpecConstantComposite instruction.
std::vector<spv::Id> const_constituents; std::vector<spv::Id> const_constituents;
for (auto NI = an->getSequence().begin(); NI != an->getSequence().end(); for (auto NI = an->getSequence().begin(); NI != an->getSequence().end();
NI++) { NI++) {
...@@ -3881,17 +3903,27 @@ spv::Id TGlslangToSpvTraverser::createSpvConstantFromConstSubTree( ...@@ -3881,17 +3903,27 @@ spv::Id TGlslangToSpvTraverser::createSpvConstantFromConstSubTree(
return const_constituents.front(); return const_constituents.front();
} }
} else if (const glslang::TIntermBinary* bn = subTree->getAsBinaryNode()) { } else if (glslang::TIntermBinary* bn = subTree->getAsBinaryNode()) {
// Binary operation node, we should generate OpSpecConstantOp <binary op> // Binary operation node, we should generate OpSpecConstantOp <binary op>
// This case should only happen when Specialization Constants are involved. // This case should only happen when Specialization Constants are involved.
spv::MissingFunctionality("OpSpecConstantOp <binary op> not implemented");
return spv::NoResult;
} else if (const glslang::TIntermUnary* un = subTree->getAsUnaryNode()) { // Spec constants defined with binary operations and other constants requires
// OpSpecConstantOp instruction.
SpecConstOpCodeGenerationSettingGuard set_to_spec_const_mode(&builder, true);
bn->traverse(this);
return accessChainLoad(bn->getType());
} else if (glslang::TIntermUnary* un = subTree->getAsUnaryNode()) {
// Unary operation node, similar to binary operation node, should only // Unary operation node, similar to binary operation node, should only
// happen when specialization constants are involved. // happen when specialization constants are involved.
spv::MissingFunctionality("OpSpecConstantOp <unary op> not implemented");
return spv::NoResult; // Spec constants defined with unary operations and other constants requires
// OpSpecConstantOp instruction.
SpecConstOpCodeGenerationSettingGuard set_to_spec_const_mode(&builder, true);
un->traverse(this);
return accessChainLoad(un->getType());
} else if (const glslang::TIntermConstantUnion* cn = subTree->getAsConstantUnion()) { } else if (const glslang::TIntermConstantUnion* cn = subTree->getAsConstantUnion()) {
// ConstantUnion node, should redirect to // ConstantUnion node, should redirect to
......
...@@ -64,7 +64,8 @@ Builder::Builder(unsigned int magicNumber) : ...@@ -64,7 +64,8 @@ Builder::Builder(unsigned int magicNumber) :
builderNumber(magicNumber), builderNumber(magicNumber),
buildPoint(0), buildPoint(0),
uniqueId(0), uniqueId(0),
mainFunction(0) mainFunction(0),
generatingOpCodeForSpecConst(false)
{ {
clearAccessChain(); clearAccessChain();
} }
...@@ -1063,6 +1064,11 @@ Id Builder::createArrayLength(Id base, unsigned int member) ...@@ -1063,6 +1064,11 @@ Id Builder::createArrayLength(Id base, unsigned int member)
Id Builder::createCompositeExtract(Id composite, Id typeId, unsigned index) Id Builder::createCompositeExtract(Id composite, Id typeId, unsigned index)
{ {
// Generate code for spec constants if in spec constant operation
// generation mode.
if (generatingOpCodeForSpecConst) {
return createSpecConstantOp(OpCompositeExtract, typeId, {composite}, {index});
}
Instruction* extract = new Instruction(getUniqueId(), typeId, OpCompositeExtract); Instruction* extract = new Instruction(getUniqueId(), typeId, OpCompositeExtract);
extract->addIdOperand(composite); extract->addIdOperand(composite);
extract->addImmediateOperand(index); extract->addImmediateOperand(index);
...@@ -1073,6 +1079,11 @@ Id Builder::createCompositeExtract(Id composite, Id typeId, unsigned index) ...@@ -1073,6 +1079,11 @@ Id Builder::createCompositeExtract(Id composite, Id typeId, unsigned index)
Id Builder::createCompositeExtract(Id composite, Id typeId, std::vector<unsigned>& indexes) Id Builder::createCompositeExtract(Id composite, Id typeId, std::vector<unsigned>& indexes)
{ {
// Generate code for spec constants if in spec constant operation
// generation mode.
if (generatingOpCodeForSpecConst) {
return createSpecConstantOp(OpCompositeExtract, typeId, {composite}, indexes);
}
Instruction* extract = new Instruction(getUniqueId(), typeId, OpCompositeExtract); Instruction* extract = new Instruction(getUniqueId(), typeId, OpCompositeExtract);
extract->addIdOperand(composite); extract->addIdOperand(composite);
for (int i = 0; i < (int)indexes.size(); ++i) for (int i = 0; i < (int)indexes.size(); ++i)
...@@ -1170,6 +1181,11 @@ void Builder::createMemoryBarrier(unsigned executionScope, unsigned memorySemant ...@@ -1170,6 +1181,11 @@ void Builder::createMemoryBarrier(unsigned executionScope, unsigned memorySemant
// An opcode that has one operands, a result id, and a type // An opcode that has one operands, a result id, and a type
Id Builder::createUnaryOp(Op opCode, Id typeId, Id operand) Id Builder::createUnaryOp(Op opCode, Id typeId, Id operand)
{ {
// Generate code for spec constants if in spec constant operation
// generation mode.
if (generatingOpCodeForSpecConst) {
return createSpecConstantOp(opCode, typeId, {operand}, {});
}
Instruction* op = new Instruction(getUniqueId(), typeId, opCode); Instruction* op = new Instruction(getUniqueId(), typeId, opCode);
op->addIdOperand(operand); op->addIdOperand(operand);
buildPoint->addInstruction(std::unique_ptr<Instruction>(op)); buildPoint->addInstruction(std::unique_ptr<Instruction>(op));
...@@ -1179,6 +1195,11 @@ Id Builder::createUnaryOp(Op opCode, Id typeId, Id operand) ...@@ -1179,6 +1195,11 @@ Id Builder::createUnaryOp(Op opCode, Id typeId, Id operand)
Id Builder::createBinOp(Op opCode, Id typeId, Id left, Id right) Id Builder::createBinOp(Op opCode, Id typeId, Id left, Id right)
{ {
// Generate code for spec constants if in spec constant operation
// generation mode.
if (generatingOpCodeForSpecConst) {
return createSpecConstantOp(opCode, typeId, {left, right}, {});
}
Instruction* op = new Instruction(getUniqueId(), typeId, opCode); Instruction* op = new Instruction(getUniqueId(), typeId, opCode);
op->addIdOperand(left); op->addIdOperand(left);
op->addIdOperand(right); op->addIdOperand(right);
...@@ -1208,6 +1229,102 @@ Id Builder::createOp(Op opCode, Id typeId, const std::vector<Id>& operands) ...@@ -1208,6 +1229,102 @@ Id Builder::createOp(Op opCode, Id typeId, const std::vector<Id>& operands)
return op->getResultId(); return op->getResultId();
} }
Id Builder::createSpecConstantOp(Op opCode, Id typeId, const std::vector<Id>& operands, const std::vector<unsigned>& literals) {
switch(opCode) {
// OpCodes that do not need any capababilities.
case OpSConvert:
case OpFConvert:
case OpSNegate:
case OpNot:
case OpIAdd:
case OpISub:
case OpIMul:
case OpUDiv:
case OpSDiv:
case OpUMod:
case OpSRem:
case OpSMod:
case OpShiftRightLogical:
case OpShiftRightArithmetic:
case OpShiftLeftLogical:
case OpBitwiseOr:
case OpBitwiseXor:
case OpBitwiseAnd:
case OpVectorShuffle:
case OpCompositeExtract:
case OpCompositeInsert:
case OpLogicalOr:
case OpLogicalAnd:
case OpLogicalNot:
case OpLogicalEqual:
case OpLogicalNotEqual:
case OpSelect:
case OpIEqual:
case OpULessThan:
case OpSLessThan:
case OpUGreaterThan:
case OpSGreaterThan:
case OpULessThanEqual:
case OpSLessThanEqual:
case OpUGreaterThanEqual:
case OpSGreaterThanEqual:
// Added temporarily to enable compsite type spec constants comparison.
// Remove this comment after Spec being updated.
case OpAll:
case OpAny:
case OpFOrdEqual:
case OpFUnordEqual:
case OpFOrdNotEqual:
case OpFUnordNotEqual:
break;
// OpCodes that need Shader capability.
case OpQuantizeToF16:
addCapability(CapabilityShader);
break;
// OpCodes that need Kernel capability.
case OpConvertFToS:
case OpConvertSToF:
case OpConvertFToU:
case OpConvertUToF:
case OpUConvert:
case OpConvertPtrToU:
case OpConvertUToPtr:
case OpGenericCastToPtr:
case OpPtrCastToGeneric:
case OpBitcast:
case OpFNegate:
case OpFAdd:
case OpFSub:
case OpFMul:
case OpFDiv:
case OpFRem:
case OpFMod:
case OpAccessChain:
case OpInBoundsAccessChain:
case OpPtrAccessChain:
case OpInBoundsPtrAccessChain:
addCapability(CapabilityKernel);
break;
default:
// Invalid OpCode for Spec Constant operations.
return NoResult;
}
Instruction* op = new Instruction(getUniqueId(), typeId, OpSpecConstantOp);
op->addImmediateOperand((unsigned) opCode);
for (auto it = operands.cbegin(); it != operands.cend(); ++it)
op->addIdOperand(*it);
for (auto it = literals.cbegin(); it != literals.cend(); ++it)
op->addImmediateOperand(*it);
module.mapInstruction(op);
constantsTypesGlobals.push_back(std::unique_ptr<Instruction>(op));
return op->getResultId();
}
Id Builder::createFunctionCall(spv::Function* function, std::vector<spv::Id>& args) Id Builder::createFunctionCall(spv::Function* function, std::vector<spv::Id>& args)
{ {
Instruction* op = new Instruction(getUniqueId(), function->getReturnType(), OpFunctionCall); Instruction* op = new Instruction(getUniqueId(), function->getReturnType(), OpFunctionCall);
...@@ -1225,6 +1342,9 @@ Id Builder::createRvalueSwizzle(Decoration precision, Id typeId, Id source, std: ...@@ -1225,6 +1342,9 @@ Id Builder::createRvalueSwizzle(Decoration precision, Id typeId, Id source, std:
if (channels.size() == 1) if (channels.size() == 1)
return setPrecision(createCompositeExtract(source, typeId, channels.front()), precision); return setPrecision(createCompositeExtract(source, typeId, channels.front()), precision);
if (generatingOpCodeForSpecConst) {
return setPrecision(createSpecConstantOp(OpVectorShuffle, typeId, {source, source}, channels), precision);
}
Instruction* swizzle = new Instruction(getUniqueId(), typeId, OpVectorShuffle); Instruction* swizzle = new Instruction(getUniqueId(), typeId, OpVectorShuffle);
assert(isVector(source)); assert(isVector(source));
swizzle->addIdOperand(source); swizzle->addIdOperand(source);
...@@ -1290,10 +1410,23 @@ Id Builder::smearScalar(Decoration precision, Id scalar, Id vectorType) ...@@ -1290,10 +1410,23 @@ Id Builder::smearScalar(Decoration precision, Id scalar, Id vectorType)
if (numComponents == 1) if (numComponents == 1)
return scalar; return scalar;
Instruction* smear = new Instruction(getUniqueId(), vectorType, OpCompositeConstruct); Instruction* smear = nullptr;
for (int c = 0; c < numComponents; ++c) if (generatingOpCodeForSpecConst) {
smear->addIdOperand(scalar); auto members = std::vector<spv::Id>(numComponents, scalar);
buildPoint->addInstruction(std::unique_ptr<Instruction>(smear)); // 'scalar' can not be spec constant here. All spec constant involved
// promotion is done in createSpvConstantFromConstUnionArray(). This
// 'if' branch is only accessed when 'scalar' is used in the def-chain
// of other vector type spec constants. In such cases, all the
// instructions needed to promote 'scalar' to a vector type constants
// should be added at module level.
auto result_id = makeCompositeConstant(vectorType, members, false);
smear = module.getInstruction(result_id);
} else {
smear = new Instruction(getUniqueId(), vectorType, OpCompositeConstruct);
for (int c = 0; c < numComponents; ++c)
smear->addIdOperand(scalar);
buildPoint->addInstruction(std::unique_ptr<Instruction>(smear));
}
return setPrecision(smear->getResultId(), precision); return setPrecision(smear->getResultId(), precision);
} }
......
...@@ -262,6 +262,7 @@ public: ...@@ -262,6 +262,7 @@ public:
Id createTriOp(Op, Id typeId, Id operand1, Id operand2, Id operand3); Id createTriOp(Op, Id typeId, Id operand1, Id operand2, Id operand3);
Id createOp(Op, Id typeId, const std::vector<Id>& operands); Id createOp(Op, Id typeId, const std::vector<Id>& operands);
Id createFunctionCall(spv::Function*, std::vector<spv::Id>&); Id createFunctionCall(spv::Function*, std::vector<spv::Id>&);
Id createSpecConstantOp(Op, Id typeId, const std::vector<spv::Id>& operands, const std::vector<unsigned>& literals);
// Take an rvalue (source) and a set of channels to extract from it to // Take an rvalue (source) and a set of channels to extract from it to
// make a new rvalue, which is returned. // make a new rvalue, which is returned.
...@@ -521,6 +522,13 @@ public: ...@@ -521,6 +522,13 @@ public:
void createConditionalBranch(Id condition, Block* thenBlock, Block* elseBlock); void createConditionalBranch(Id condition, Block* thenBlock, Block* elseBlock);
void createLoopMerge(Block* mergeBlock, Block* continueBlock, unsigned int control); void createLoopMerge(Block* mergeBlock, Block* continueBlock, unsigned int control);
// Sets to generate opcode for specialization constants.
void setToSpecConstCodeGenMode() { generatingOpCodeForSpecConst = true; }
// Sets to generate opcode for non-specialization constants (normal mode).
void setToNormalCodeGenMode() { generatingOpCodeForSpecConst = false; }
// Check if the builder is generating code for spec constants.
bool isInSpecConstCodeGenMode() { return generatingOpCodeForSpecConst; }
protected: protected:
Id makeIntConstant(Id typeId, unsigned value, bool specConstant); Id makeIntConstant(Id typeId, unsigned value, bool specConstant);
Id findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned value) const; Id findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned value) const;
...@@ -544,6 +552,7 @@ public: ...@@ -544,6 +552,7 @@ public:
Block* buildPoint; Block* buildPoint;
Id uniqueId; Id uniqueId;
Function* mainFunction; Function* mainFunction;
bool generatingOpCodeForSpecConst;
AccessChain accessChain; AccessChain accessChain;
// special blocks of instructions for output // special blocks of instructions for output
......
#version 450
layout(constant_id = 200) const float sp_float = 3.1415926;
layout(constant_id = 201) const int sp_int = 10;
layout(constant_id = 202) const uint sp_uint = 100;
layout(constant_id = 203) const int sp_sint = -10;
//
// Scalars
//
// Size convert
const double float_to_double = double(sp_float);
const float double_to_float = float(float_to_double);
// Negate and Not
const int negate_int = -sp_int;
const int not_int = ~sp_int;
// Add and Subtract
const int sp_int_add_two = sp_int + 2;
const int sp_int_add_two_sub_three = sp_int + 2 - 3;
const int sp_int_add_two_sub_four = sp_int_add_two - 4;
// Mul, Div and Rem
const int sp_sint_mul_two = sp_sint * 2;
const uint sp_uint_mul_two = sp_uint * 2;
const int sp_sint_mul_two_div_five = sp_sint_mul_two / 5;
const uint sp_uint_mul_two_div_five = sp_uint_mul_two / 5;
const int sp_sint_rem_four = sp_sint % 4;
const uint sp_uint_rem_four = sp_uint % 4;
const int sp_sint_mul_three_div_five = sp_sint * 3 / 5;
// Shift
const int sp_sint_shift_right_arithmetic = sp_sint >> 10;
const uint sp_uint_shift_right_arithmetic = sp_uint >> 20;
const int sp_sint_shift_left = sp_sint << 1;
const uint sp_uint_shift_left = sp_uint << 2;
// Bitwise And, Or, Xor
const int sp_sint_or_256 = sp_sint | 0x100;
const uint sp_uint_xor_512 = sp_uint ^ 0x200;
const int sp_sint_and_sp_uint = sp_sint & int(sp_uint);
// Scalar comparison
const bool sp_int_lt_sp_sint = sp_int < sp_sint;
const bool sp_uint_lt_sp_sint = sp_uint < sp_sint;
const bool sp_sint_lt_sp_uint = sp_sint < sp_uint;
//
// Vectors
//
const ivec4 iv = ivec4(20, 30, sp_int, sp_int);
const uvec4 uv = uvec4(sp_uint, sp_uint, -1, -2);
const vec4 fv = vec4(sp_float, 1.25, sp_float, 1.25);
// Size convert
const dvec4 fv_to_dv = dvec4(fv);
const vec4 dv_to_fv = vec4(fv_to_dv);
// Negate and Not
const vec4 not_iv = ~iv;
const ivec4 negate_iv = -iv;
// Add and Subtract
const ivec4 iv_add_two = iv + 2;
const ivec4 iv_add_two_sub_three = iv + 2 - 3;
const ivec4 iv_add_two_sub_four = iv_add_two_sub_three - 4;
// Mul, Div and Rem
const ivec4 iv_mul_two = iv * 2;
const ivec4 iv_mul_two_div_five = iv_mul_two / 5;
const ivec4 iv_rem_four = iv % 4;
// Shift
const ivec4 iv_shift_right_arithmetic = iv >> 10;
const ivec4 iv_shift_left = iv << 2;
// Bitwise And, Or, Xor
const ivec4 iv_or_1024 = iv | 0x400;
const uvec4 uv_xor_2048 = uv ^ 0x800;
const ivec4 iv_and_uv = iv & ivec4(uv);
// Swizzles
const int iv_x = iv.x;
const ivec2 iv_yx = iv.yx;
const ivec3 iv_zyx = iv.zyx;
const ivec4 iv_yzxw = iv.yzxw;
// Vector comparison, only == and != are supported and allowd.
const bool iv_equal_uv = iv == uv;
const bool iv_not_equal_uv = iv != uv;
//
// Composite types other than vectors
//
// Struct
struct int_float_double_vec2 {
int i;
float f;
double d;
vec2 v;
};
const int_float_double_vec2 sp_struct_a = {
sp_int, sp_float, float_to_double,
vec2(double_to_float, 1.0)
};
const int_float_double_vec2 sp_struct_b = {
sp_int, sp_float, float_to_double,
vec2(double_to_float, 1.0)
};
const bool struct_a_equal_struct_b = sp_struct_a == sp_struct_b;
const bool struct_a_not_equal_struct_b = sp_struct_a != sp_struct_b;
// Array
const float array_a[2] = {sp_float, sp_float};
const float array_b[2] = {sp_float, sp_float};
const bool array_a_equal_array_b = array_a == array_b;
const bool array_a_not_equal_array_b = array_a != array_b;
void main() {}
...@@ -103,6 +103,7 @@ spv.subpass.frag ...@@ -103,6 +103,7 @@ spv.subpass.frag
spv.specConstant.vert spv.specConstant.vert
spv.specConstant.comp spv.specConstant.comp
spv.specConstantComposite.vert spv.specConstantComposite.vert
spv.specConstantOperations.vert
# GLSL-level semantics # GLSL-level semantics
vulkan.frag vulkan.frag
vulkan.vert vulkan.vert
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment