Unverified Commit 792a9480 by John Kessenich Committed by GitHub

Merge pull request #1161 from LoopDawg/matmul-truncate

WIP: HLSL: matrix and vector truncations for m*v, v*m, m*m
parents 698bf754 2e629106
// Test v*v, v*m, m*v, and m*m argument clamping.
cbuffer Matrix
{
float4x4 m44;
float4x3 m43;
float3x4 m34;
float2x4 m24;
float4x2 m42;
float4 v4;
float3 v3;
float2 v2;
}
float4 main() : SV_Target0
{
// v*v:
float r00 = mul(v2, v3); // float = float2*float3; // clamp to float2 dot product
float r01 = mul(v4, v2); // float = float4*float2; // clamp to float2 dot product
// v*m
float4 r10 = mul(v3, m44); // float4 = float3 * float4x4; // clamp mat to float3x4;
float4 r11 = mul(v4, m34); // truncate vector to vec3
// m*v
float4 r20 = mul(m44, v3); // float4 = float4x4 * float3; // clamp mat to float4x3;
float4 r21 = mul(m43, v4); // truncate vector to vec3
// // m*m
// float2x3 r30 = mul(m24, m33); // float2x3 = float2x4 * float3x3;
// float3x4 r31 = mul(m33, m24); // float3x4 = float3x3 * float2x4;
// float3x2 r32 = mul(m33, m42); // float3x2 = float3x3 * float4x2;
// float4x3 r33 = mul(m42, m33); // float4x3 = float4x2 * float3x3;
return r10 + r11 + r20 + r21 + r00 + r01; // + r30[0].x + r31[0] + r32[0].x + transpose(r33)[0];
}
...@@ -1152,6 +1152,7 @@ public: ...@@ -1152,6 +1152,7 @@ public:
sampler.clear(); sampler.clear();
qualifier.clear(); qualifier.clear();
qualifier.storage = q; qualifier.storage = q;
assert(!(isMatrix() && vectorSize != 0)); // prevent vectorSize != 0 on matrices
} }
// for explicit precision qualifier // for explicit precision qualifier
TType(TBasicType t, TStorageQualifier q, TPrecisionQualifier p, int vs = 1, int mc = 0, int mr = 0, TType(TBasicType t, TStorageQualifier q, TPrecisionQualifier p, int vs = 1, int mc = 0, int mr = 0,
...@@ -1164,6 +1165,7 @@ public: ...@@ -1164,6 +1165,7 @@ public:
qualifier.storage = q; qualifier.storage = q;
qualifier.precision = p; qualifier.precision = p;
assert(p >= EpqNone && p <= EpqHigh); assert(p >= EpqNone && p <= EpqHigh);
assert(!(isMatrix() && vectorSize != 0)); // prevent vectorSize != 0 on matrices
} }
// for turning a TPublicType into a TType, using a shallow copy // for turning a TPublicType into a TType, using a shallow copy
explicit TType(const TPublicType& p) : explicit TType(const TPublicType& p) :
......
...@@ -238,6 +238,7 @@ INSTANTIATE_TEST_CASE_P( ...@@ -238,6 +238,7 @@ INSTANTIATE_TEST_CASE_P(
{"hlsl.matrixSwizzle.vert", "ShaderFunction"}, {"hlsl.matrixSwizzle.vert", "ShaderFunction"},
{"hlsl.memberFunCall.frag", "main"}, {"hlsl.memberFunCall.frag", "main"},
{"hlsl.mintypes.frag", "main"}, {"hlsl.mintypes.frag", "main"},
{"hlsl.mul-truncate.frag", "main"},
{"hlsl.multiEntry.vert", "RealEntrypoint"}, {"hlsl.multiEntry.vert", "RealEntrypoint"},
{"hlsl.multiReturn.frag", "main"}, {"hlsl.multiReturn.frag", "main"},
{"hlsl.matrixindex.frag", "main"}, {"hlsl.matrixindex.frag", "main"},
......
...@@ -5008,6 +5008,12 @@ TIntermTyped* HlslParseContext::handleFunctionCall(const TSourceLoc& loc, TFunct ...@@ -5008,6 +5008,12 @@ TIntermTyped* HlslParseContext::handleFunctionCall(const TSourceLoc& loc, TFunct
bool builtIn = false; bool builtIn = false;
int thisDepth = 0; int thisDepth = 0;
// For mat mul, the situation is unusual: we have to compare vector sizes to mat row or col sizes,
// and clamp the opposite arg. Since that's complex, we farm it off to a separate method.
// It doesn't naturally fall out of processing an argument at a time in isolation.
if (function->getName() == "mul")
addGenMulArgumentConversion(loc, *function, arguments);
TIntermAggregate* aggregate = arguments ? arguments->getAsAggregate() : nullptr; TIntermAggregate* aggregate = arguments ? arguments->getAsAggregate() : nullptr;
// TODO: this needs improvement: there's no way at present to look up a signature in // TODO: this needs improvement: there's no way at present to look up a signature in
...@@ -5170,6 +5176,68 @@ void HlslParseContext::pushFrontArguments(TIntermTyped* front, TIntermTyped*& ar ...@@ -5170,6 +5176,68 @@ void HlslParseContext::pushFrontArguments(TIntermTyped* front, TIntermTyped*& ar
} }
// //
// HLSL allows mismatched dimensions on vec*mat, mat*vec, vec*vec, and mat*mat. This is a
// situation not well suited to resolution in intrinsic selection, but we can do so here, since we
// can look at both arguments insert explicit shape changes here, if required.
//
void HlslParseContext::addGenMulArgumentConversion(const TSourceLoc& loc, TFunction& call, TIntermTyped*& args)
{
TIntermAggregate* argAggregate = args ? args->getAsAggregate() : nullptr;
if (argAggregate == nullptr || argAggregate->getSequence().size() != 2) {
// It really ought to have two arguments.
error(loc, "expected: mul arguments", "", "");
return;
}
TIntermTyped* arg0 = argAggregate->getSequence()[0]->getAsTyped();
TIntermTyped* arg1 = argAggregate->getSequence()[1]->getAsTyped();
if (arg0->isVector() && arg1->isVector()) {
// For:
// vec * vec: it's handled during intrinsic selection, so while we could do it here,
// we can also ignore it, which is easier.
} else if (arg0->isVector() && arg1->isMatrix()) {
// vec * mat: we clamp the vec if the mat col is smaller, else clamp the mat col.
if (arg0->getVectorSize() < arg1->getMatrixCols()) {
// vec is smaller, so truncate larger mat dimension
const TType truncType(arg1->getBasicType(), arg1->getQualifier().storage, arg1->getQualifier().precision,
0, arg0->getVectorSize(), arg1->getMatrixRows());
arg1 = addConstructor(loc, arg1, truncType);
} else if (arg0->getVectorSize() > arg1->getMatrixCols()) {
// vec is larger, so truncate vec to mat size
const TType truncType(arg0->getBasicType(), arg0->getQualifier().storage, arg0->getQualifier().precision,
arg1->getMatrixCols());
arg0 = addConstructor(loc, arg0, truncType);
}
} else if (arg0->isMatrix() && arg1->isVector()) {
// mat * vec: we clamp the vec if the mat col is smaller, else clamp the mat col.
if (arg1->getVectorSize() < arg0->getMatrixRows()) {
// vec is smaller, so truncate larger mat dimension
const TType truncType(arg0->getBasicType(), arg0->getQualifier().storage, arg0->getQualifier().precision,
0, arg0->getMatrixCols(), arg1->getVectorSize());
arg0 = addConstructor(loc, arg0, truncType);
} else if (arg1->getVectorSize() > arg0->getMatrixRows()) {
// vec is larger, so truncate vec to mat size
const TType truncType(arg1->getBasicType(), arg1->getQualifier().storage, arg1->getQualifier().precision,
arg0->getMatrixRows());
arg1 = addConstructor(loc, arg1, truncType);
}
} else if (arg0->isMatrix() && arg1->isMatrix()) {
// mat * mat
} else {
// It's something with scalars: we'll just leave it alone.
}
// Put arguments back.
argAggregate->getSequence()[0] = arg0;
argAggregate->getSequence()[1] = arg1;
call[0].type = &arg0->getWritableType();
call[1].type = &arg1->getWritableType();
}
//
// Add any needed implicit conversions for function-call arguments to input parameters. // Add any needed implicit conversions for function-call arguments to input parameters.
// //
void HlslParseContext::addInputArgumentConversions(const TFunction& function, TIntermTyped*& arguments) void HlslParseContext::addInputArgumentConversions(const TFunction& function, TIntermTyped*& arguments)
...@@ -7015,6 +7083,7 @@ void HlslParseContext::mergeObjectLayoutQualifiers(TQualifier& dst, const TQuali ...@@ -7015,6 +7083,7 @@ void HlslParseContext::mergeObjectLayoutQualifiers(TQualifier& dst, const TQuali
} }
} }
// //
// Look up a function name in the symbol table, and make sure it is a function. // Look up a function name in the symbol table, and make sure it is a function.
// //
......
...@@ -141,6 +141,7 @@ public: ...@@ -141,6 +141,7 @@ public:
void checkNoShaderLayouts(const TSourceLoc&, const TShaderQualifiers&); void checkNoShaderLayouts(const TSourceLoc&, const TShaderQualifiers&);
const TFunction* findFunction(const TSourceLoc& loc, TFunction& call, bool& builtIn, int& thisDepth, TIntermTyped*& args); const TFunction* findFunction(const TSourceLoc& loc, TFunction& call, bool& builtIn, int& thisDepth, TIntermTyped*& args);
void addGenMulArgumentConversion(const TSourceLoc& loc, TFunction& call, TIntermTyped*& args);
void declareTypedef(const TSourceLoc&, const TString& identifier, const TType&); void declareTypedef(const TSourceLoc&, const TString& identifier, const TType&);
void declareStruct(const TSourceLoc&, TString& structName, TType&); void declareStruct(const TSourceLoc&, TString& structName, TType&);
TSymbol* lookupUserType(const TString&, TType&); TSymbol* lookupUserType(const TString&, TType&);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment