Commit 2211835b by John Kessenich

SPV: Implement composite comparisons (reductions across hierchical compare).

parent 59420fd3
...@@ -2274,7 +2274,7 @@ spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv ...@@ -2274,7 +2274,7 @@ spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv
if (reduceComparison && (builder.isVector(left) || builder.isMatrix(left) || builder.isAggregate(left))) { if (reduceComparison && (builder.isVector(left) || builder.isMatrix(left) || builder.isAggregate(left))) {
assert(op == glslang::EOpEqual || op == glslang::EOpNotEqual); assert(op == glslang::EOpEqual || op == glslang::EOpNotEqual);
return builder.createCompare(precision, left, right, op == glslang::EOpEqual); return builder.createCompositeCompare(precision, left, right, op == glslang::EOpEqual);
} }
switch (op) { switch (op) {
......
...@@ -435,7 +435,7 @@ Op Builder::getMostBasicTypeClass(Id typeId) const ...@@ -435,7 +435,7 @@ Op Builder::getMostBasicTypeClass(Id typeId) const
} }
} }
int Builder::getNumTypeComponents(Id typeId) const int Builder::getNumTypeConstituents(Id typeId) const
{ {
Instruction* instr = module.getInstruction(typeId); Instruction* instr = module.getInstruction(typeId);
...@@ -447,7 +447,10 @@ int Builder::getNumTypeComponents(Id typeId) const ...@@ -447,7 +447,10 @@ int Builder::getNumTypeComponents(Id typeId) const
return 1; return 1;
case OpTypeVector: case OpTypeVector:
case OpTypeMatrix: case OpTypeMatrix:
case OpTypeArray:
return instr->getImmediateOperand(1); return instr->getImmediateOperand(1);
case OpTypeStruct:
return instr->getNumOperands();
default: default:
assert(0); assert(0);
return 1; return 1;
...@@ -1411,88 +1414,78 @@ Id Builder::createTextureQueryCall(Op opCode, const TextureParameters& parameter ...@@ -1411,88 +1414,78 @@ Id Builder::createTextureQueryCall(Op opCode, const TextureParameters& parameter
return query->getResultId(); return query->getResultId();
} }
// Comments in header // External comments in header.
Id Builder::createCompare(Decoration precision, Id value1, Id value2, bool equal) // Operates recursively to visit the composite's hierarchy.
Id Builder::createCompositeCompare(Decoration precision, Id value1, Id value2, bool equal)
{ {
Id boolType = makeBoolType(); Id boolType = makeBoolType();
Id valueType = getTypeId(value1); Id valueType = getTypeId(value1);
assert(valueType == getTypeId(value2)); assert(valueType == getTypeId(value2));
assert(! isScalar(value1));
// Vectors Id resultId;
int numConstituents = getNumTypeConstituents(valueType);
// Scalars and Vectors
if (isVectorType(valueType)) { if (isScalarType(valueType) || isVectorType(valueType)) {
Id boolVectorType = makeVectorType(boolType, getNumTypeComponents(valueType)); // These just need a single comparison, just have
Id boolVector; // to figure out what it is.
Op op; Op op;
if (getMostBasicTypeClass(valueType) == OpTypeFloat) switch (getMostBasicTypeClass(valueType)) {
case OpTypeFloat:
op = equal ? OpFOrdEqual : OpFOrdNotEqual; op = equal ? OpFOrdEqual : OpFOrdNotEqual;
else break;
case OpTypeInt:
op = equal ? OpIEqual : OpINotEqual; op = equal ? OpIEqual : OpINotEqual;
break;
case OpTypeBool:
op = equal ? OpLogicalEqual : OpLogicalNotEqual;
precision = NoPrecision;
break;
}
boolVector = createBinOp(op, boolVectorType, value1, value2); if (isScalarType(valueType)) {
setPrecision(boolVector, precision); // scalar
resultId = createBinOp(op, boolType, value1, value2);
// Reduce vector compares with any() and all(). setPrecision(resultId, precision);
} else {
op = equal ? OpAll : OpAny; // vector
resultId = createBinOp(op, makeVectorType(boolType, numConstituents), value1, value2);
setPrecision(resultId, precision);
// reduce vector compares...
resultId = createUnaryOp(equal ? OpAll : OpAny, boolType, resultId);
}
return createUnaryOp(op, boolType, boolVector); return resultId;
} }
spv::MissingFunctionality("Composite comparison of non-vectors"); // Only structs, arrays, and matrices should be left.
// They share in common the reduction operation across their constituents.
return NoResult; assert(isAggregateType(valueType) || isMatrixType(valueType));
// Recursively handle aggregates, which include matrices, arrays, and structures
// and accumulate the results.
// Matrices
// Arrays // Compare each pair of constituents
for (int constituent = 0; constituent < numConstituents; ++constituent) {
std::vector<unsigned> indexes(1, constituent);
Id constituentType = getContainedTypeId(valueType, constituent);
Id constituent1 = createCompositeExtract(value1, constituentType, indexes);
Id constituent2 = createCompositeExtract(value2, constituentType, indexes);
//int numElements; Id subResultId = createCompositeCompare(precision, constituent1, constituent2, equal);
//const llvm::ArrayType* arrayType = llvm::dyn_cast<llvm::ArrayType>(value1->getType());
//if (arrayType)
// numElements = (int)arrayType->getNumElements();
//else {
// // better be structure
// const llvm::StructType* structType = llvm::dyn_cast<llvm::StructType>(value1->getType());
// assert(structType);
// numElements = structType->getNumElements();
//}
//assert(numElements > 0); if (constituent == 0)
resultId = subResultId;
//for (int element = 0; element < numElements; ++element) { else
// // Get intermediate comparison values resultId = createBinOp(equal ? OpLogicalAnd : OpLogicalOr, boolType, resultId, subResultId);
// llvm::Value* element1 = builder.CreateExtractValue(value1, element, "element1"); }
// setInstructionPrecision(element1, precision);
// llvm::Value* element2 = builder.CreateExtractValue(value2, element, "element2");
// setInstructionPrecision(element2, precision);
// llvm::Value* subResult = createCompare(precision, element1, element2, equal, "comp");
// // Accumulate intermediate comparison
// if (element == 0)
// result = subResult;
// else {
// if (equal)
// result = builder.CreateAnd(result, subResult);
// else
// result = builder.CreateOr(result, subResult);
// setInstructionPrecision(result, precision);
// }
//}
//return result; return resultId;
} }
// OpCompositeConstruct // OpCompositeConstruct
Id Builder::createCompositeConstruct(Id typeId, std::vector<Id>& constituents) Id Builder::createCompositeConstruct(Id typeId, std::vector<Id>& constituents)
{ {
assert(isAggregateType(typeId) || (getNumTypeComponents(typeId) > 1 && getNumTypeComponents(typeId) == (int)constituents.size())); assert(isAggregateType(typeId) || (getNumTypeConstituents(typeId) > 1 && getNumTypeConstituents(typeId) == (int)constituents.size()));
Instruction* op = new Instruction(getUniqueId(), typeId, OpCompositeConstruct); Instruction* op = new Instruction(getUniqueId(), typeId, OpCompositeConstruct);
for (int c = 0; c < (int)constituents.size(); ++c) for (int c = 0; c < (int)constituents.size(); ++c)
......
...@@ -116,7 +116,8 @@ public: ...@@ -116,7 +116,8 @@ public:
Op getTypeClass(Id typeId) const { return getOpCode(typeId); } Op getTypeClass(Id typeId) const { return getOpCode(typeId); }
Op getMostBasicTypeClass(Id typeId) const; Op getMostBasicTypeClass(Id typeId) const;
int getNumComponents(Id resultId) const { return getNumTypeComponents(getTypeId(resultId)); } int getNumComponents(Id resultId) const { return getNumTypeComponents(getTypeId(resultId)); }
int getNumTypeComponents(Id typeId) const; int getNumTypeConstituents(Id typeId) const;
int getNumTypeComponents(Id typeId) const { return getNumTypeConstituents(typeId); }
Id getScalarTypeId(Id typeId) const; Id getScalarTypeId(Id typeId) const;
Id getContainedTypeId(Id typeId) const; Id getContainedTypeId(Id typeId) const;
Id getContainedTypeId(Id typeId, int) const; Id getContainedTypeId(Id typeId, int) const;
...@@ -150,7 +151,7 @@ public: ...@@ -150,7 +151,7 @@ public:
int getTypeNumColumns(Id typeId) const int getTypeNumColumns(Id typeId) const
{ {
assert(isMatrixType(typeId)); assert(isMatrixType(typeId));
return getNumTypeComponents(typeId); return getNumTypeConstituents(typeId);
} }
int getNumColumns(Id resultId) const { return getTypeNumColumns(getTypeId(resultId)); } int getNumColumns(Id resultId) const { return getTypeNumColumns(getTypeId(resultId)); }
int getTypeNumRows(Id typeId) const int getTypeNumRows(Id typeId) const
...@@ -265,11 +266,13 @@ public: ...@@ -265,11 +266,13 @@ public:
// (No true lvalue or stores are used.) // (No true lvalue or stores are used.)
Id createLvalueSwizzle(Id typeId, Id target, Id source, std::vector<unsigned>& channels); Id createLvalueSwizzle(Id typeId, Id target, Id source, std::vector<unsigned>& channels);
// If the value passed in is an instruction and the precision is not EMpNone, // If the value passed in is an instruction and the precision is not NoPrecision,
// it gets tagged with the requested precision. // it gets tagged with the requested precision.
void setPrecision(Id /* value */, Decoration /* precision */) void setPrecision(Id /* value */, Decoration precision)
{ {
// TODO if (precision != NoPrecision) {
;// TODO
}
} }
// Can smear a scalar to a vector for the following forms: // Can smear a scalar to a vector for the following forms:
...@@ -322,7 +325,7 @@ public: ...@@ -322,7 +325,7 @@ public:
Id createBitFieldInsertCall(Decoration precision, Id, Id, Id, Id); Id createBitFieldInsertCall(Decoration precision, Id, Id, Id, Id);
// Reduction comparision for composites: For equal and not-equal resulting in a scalar. // Reduction comparision for composites: For equal and not-equal resulting in a scalar.
Id createCompare(Decoration precision, Id, Id, bool /* true if for equal, fales if for not-equal */); Id createCompositeCompare(Decoration precision, Id, Id, bool /* true if for equal, false if for not-equal */);
// OpCompositeConstruct // OpCompositeConstruct
Id createCompositeConstruct(Id typeId, std::vector<Id>& constituents); Id createCompositeConstruct(Id typeId, std::vector<Id>& constituents);
......
...@@ -168,7 +168,7 @@ Linked fragment stage: ...@@ -168,7 +168,7 @@ Linked fragment stage:
18(bv2): 16(ptr) FunctionParameter 18(bv2): 16(ptr) FunctionParameter
20: Label 20: Label
27: 15(bvec2) Load 18(bv2) 27: 15(bvec2) Load 18(bv2)
31: 15(bvec2) IEqual 27 30 31: 15(bvec2) LogicalEqual 27 30
32: 14(bool) All 31 32: 14(bool) All 31
ReturnValue 32 ReturnValue 32
FunctionEnd FunctionEnd
...@@ -2,5 +2,5 @@ ...@@ -2,5 +2,5 @@
// For the version, it uses the latest git tag followed by the number of commits. // For the version, it uses the latest git tag followed by the number of commits.
// For the date, it uses the current date (when then script is run). // For the date, it uses the current date (when then script is run).
#define GLSLANG_REVISION "SPIRV99.862" #define GLSLANG_REVISION "SPIRV99.863"
#define GLSLANG_DATE "21-Dec-2015" #define GLSLANG_DATE "21-Dec-2015"
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment