Commit ce443b3a by Neil Henning

Add support for the latest shader model 6.0 wave specs.

MSVC annoyingly redid their wave ops completely; see https://github.com/Microsoft/DirectXShaderCompiler/wiki/Wave-Intrinsics for the new specification. I've had to totally rework the existing wave ops to support this. Test results: ``` [----------] Global test environment tear-down [==========] 1142 tests from 44 test cases ran. (186552 ms total) [ PASSED ] 1142 tests. ```
parent d487d4d0
......@@ -619,6 +619,7 @@ spv::BuiltIn TGlslangToSpvTraverser::TranslateBuiltInDecoration(glslang::TBuiltI
builder.addCapability(spv::CapabilityGroupNonUniform);
builder.addCapability(spv::CapabilityGroupNonUniformBallot);
return spv::BuiltInSubgroupLtMask;
#ifdef AMD_EXTENSIONS
case glslang::EbvBaryCoordNoPersp:
builder.addExtension(spv::E_SPV_AMD_shader_explicit_vertex_parameter);
......
File mode changed from 100755 to 100644
......@@ -7,7 +7,7 @@ gl_FragCoord origin is upper left
0:? Sequence
0:3 Test condition and select ( temp void)
0:3 Condition
0:3 '@gl_HelperInvocation' ( in bool HelperInvocation)
0:3 subgroupElect ( temp bool)
0:3 true case
0:? Sequence
0:5 Branch: Return with expression
......@@ -45,7 +45,7 @@ gl_FragCoord origin is upper left
0:? Sequence
0:3 Test condition and select ( temp void)
0:3 Condition
0:3 '@gl_HelperInvocation' ( in bool HelperInvocation)
0:3 subgroupElect ( temp bool)
0:3 true case
0:? Sequence
0:5 Branch: Return with expression
......@@ -76,16 +76,15 @@ gl_FragCoord origin is upper left
// Id's are bound by 30
Capability Shader
Capability GroupNonUniform
1: ExtInstImport "GLSL.std.450"
MemoryModel Logical GLSL450
EntryPoint Fragment 4 "PixelShaderFunction" 13 28
EntryPoint Fragment 4 "PixelShaderFunction" 28
ExecutionMode 4 OriginUpperLeft
Source HLSL 500
Name 4 "PixelShaderFunction"
Name 9 "@PixelShaderFunction("
Name 13 "@gl_HelperInvocation"
Name 28 "@entryPointOutput"
Decorate 13(@gl_HelperInvocation) BuiltIn HelperInvocation
Decorate 28(@entryPointOutput) Location 0
2: TypeVoid
3: TypeFunction 2
......@@ -93,8 +92,8 @@ gl_FragCoord origin is upper left
7: TypeVector 6(float) 4
8: TypeFunction 7(fvec4)
11: TypeBool
12: TypePointer Input 11(bool)
13(@gl_HelperInvocation): 12(ptr) Variable Input
12: TypeInt 32 0
13: 12(int) Constant 3
17: 6(float) Constant 1065353216
18: 6(float) Constant 1073741824
19: 6(float) Constant 1077936128
......@@ -111,7 +110,7 @@ gl_FragCoord origin is upper left
FunctionEnd
9(@PixelShaderFunction(): 7(fvec4) Function None 8
10: Label
14: 11(bool) Load 13(@gl_HelperInvocation)
14: 11(bool) GroupNonUniformElect 13
SelectionMerge 16 None
BranchConditional 14 15 23
15: Label
......
// Read/write structured buffer of uints; CSMain below writes into it.
RWStructuredBuffer<uint> data;
// Compute entry point, declared with a 32x16x1 thread group size.
// Stores the value 1 into the element of `data` whose index is the value
// returned by WaveGetOrderedIndex().
[numthreads(32, 16, 1)]
void CSMain()
{
data[WaveGetOrderedIndex()] = 1;
}
// Pixel entry point; the returned float4 carries the COLOR0 semantic.
// Returns float4(1, 2, 3, 4) when WaveGetOrderedIndex() evaluates to 0,
// and float4(4, 3, 2, 1) otherwise.
float4 PixelShaderFunction() : COLOR0
{
// Branch on the result of the wave intrinsic under test.
if (0 == WaveGetOrderedIndex())
{
return float4(1, 2, 3, 4);
}
else
{
return float4(4, 3, 2, 1);
}
}
// Read/write structured buffer of uints; CSMain below writes into it.
RWStructuredBuffer<uint> data;
// Compute entry point, declared with a 32x16x1 thread group size.
// Stores the value 1 into the element of `data` whose index is the value
// returned by GlobalOrderedCountIncrement(i), where i is a local set to 42.
[numthreads(32, 16, 1)]
void CSMain()
{
// Argument passed to the intrinsic under test.
uint i = 42;
data[GlobalOrderedCountIncrement(i)] = 1;
}
// Pixel entry point; the returned float4 carries the COLOR0 semantic.
// Returns float4(1, 2, 3, 4) when GlobalOrderedCountIncrement(i) evaluates
// to 0 (with i set to 42), and float4(4, 3, 2, 1) otherwise.
float4 PixelShaderFunction() : COLOR0
{
// Argument passed to the intrinsic under test.
uint i = 42;
if (0 == GlobalOrderedCountIncrement(i))
{
return float4(1, 2, 3, 4);
}
else
{
return float4(4, 3, 2, 1);
}
}
......@@ -50,4 +50,6 @@ void CSMain(uint3 dti : SV_DispatchThreadID)
data[dti.x].d.x = WavePrefixProduct(data[dti.x].d.x);
data[dti.x].d.xy = WavePrefixProduct(data[dti.x].d.xy);
data[dti.x].d.xyz = WavePrefixProduct(data[dti.x].d.xyz);
data[dti.x].u.x = WavePrefixCountBits(data[dti.x].u.x == 0);
}
......@@ -91,43 +91,63 @@ void CSMain(uint3 dti : SV_DispatchThreadID)
data[dti.x].d.xy = QuadReadLaneAt(data[dti.x].d.xy, 3);
data[dti.x].d.xyz = QuadReadLaneAt(data[dti.x].d.xyz, 3);
data[dti.x].u = QuadSwapX(data[dti.x].u);
data[dti.x].u.x = QuadSwapX(data[dti.x].u.x);
data[dti.x].u.xy = QuadSwapX(data[dti.x].u.xy);
data[dti.x].u.xyz = QuadSwapX(data[dti.x].u.xyz);
data[dti.x].i = QuadSwapX(data[dti.x].i);
data[dti.x].i.x = QuadSwapX(data[dti.x].i.x);
data[dti.x].i.xy = QuadSwapX(data[dti.x].i.xy);
data[dti.x].i.xyz = QuadSwapX(data[dti.x].i.xyz);
data[dti.x].f = QuadSwapX(data[dti.x].f);
data[dti.x].f.x = QuadSwapX(data[dti.x].f.x);
data[dti.x].f.xy = QuadSwapX(data[dti.x].f.xy);
data[dti.x].f.xyz = QuadSwapX(data[dti.x].f.xyz);
data[dti.x].d = QuadSwapX(data[dti.x].d);
data[dti.x].d.x = QuadSwapX(data[dti.x].d.x);
data[dti.x].d.xy = QuadSwapX(data[dti.x].d.xy);
data[dti.x].d.xyz = QuadSwapX(data[dti.x].d.xyz);
data[dti.x].u = QuadSwapY(data[dti.x].u);
data[dti.x].u.x = QuadSwapY(data[dti.x].u.x);
data[dti.x].u.xy = QuadSwapY(data[dti.x].u.xy);
data[dti.x].u.xyz = QuadSwapY(data[dti.x].u.xyz);
data[dti.x].i = QuadSwapY(data[dti.x].i);
data[dti.x].i.x = QuadSwapY(data[dti.x].i.x);
data[dti.x].i.xy = QuadSwapY(data[dti.x].i.xy);
data[dti.x].i.xyz = QuadSwapY(data[dti.x].i.xyz);
data[dti.x].f = QuadSwapY(data[dti.x].f);
data[dti.x].f.x = QuadSwapY(data[dti.x].f.x);
data[dti.x].f.xy = QuadSwapY(data[dti.x].f.xy);
data[dti.x].f.xyz = QuadSwapY(data[dti.x].f.xyz);
data[dti.x].d = QuadSwapY(data[dti.x].d);
data[dti.x].d.x = QuadSwapY(data[dti.x].d.x);
data[dti.x].d.xy = QuadSwapY(data[dti.x].d.xy);
data[dti.x].d.xyz = QuadSwapY(data[dti.x].d.xyz);
data[dti.x].u = QuadReadAcrossX(data[dti.x].u);
data[dti.x].u.x = QuadReadAcrossX(data[dti.x].u.x);
data[dti.x].u.xy = QuadReadAcrossX(data[dti.x].u.xy);
data[dti.x].u.xyz = QuadReadAcrossX(data[dti.x].u.xyz);
data[dti.x].i = QuadReadAcrossX(data[dti.x].i);
data[dti.x].i.x = QuadReadAcrossX(data[dti.x].i.x);
data[dti.x].i.xy = QuadReadAcrossX(data[dti.x].i.xy);
data[dti.x].i.xyz = QuadReadAcrossX(data[dti.x].i.xyz);
data[dti.x].f = QuadReadAcrossX(data[dti.x].f);
data[dti.x].f.x = QuadReadAcrossX(data[dti.x].f.x);
data[dti.x].f.xy = QuadReadAcrossX(data[dti.x].f.xy);
data[dti.x].f.xyz = QuadReadAcrossX(data[dti.x].f.xyz);
data[dti.x].d = QuadReadAcrossX(data[dti.x].d);
data[dti.x].d.x = QuadReadAcrossX(data[dti.x].d.x);
data[dti.x].d.xy = QuadReadAcrossX(data[dti.x].d.xy);
data[dti.x].d.xyz = QuadReadAcrossX(data[dti.x].d.xyz);
data[dti.x].u = QuadReadAcrossY(data[dti.x].u);
data[dti.x].u.x = QuadReadAcrossY(data[dti.x].u.x);
data[dti.x].u.xy = QuadReadAcrossY(data[dti.x].u.xy);
data[dti.x].u.xyz = QuadReadAcrossY(data[dti.x].u.xyz);
data[dti.x].i = QuadReadAcrossY(data[dti.x].i);
data[dti.x].i.x = QuadReadAcrossY(data[dti.x].i.x);
data[dti.x].i.xy = QuadReadAcrossY(data[dti.x].i.xy);
data[dti.x].i.xyz = QuadReadAcrossY(data[dti.x].i.xyz);
data[dti.x].f = QuadReadAcrossY(data[dti.x].f);
data[dti.x].f.x = QuadReadAcrossY(data[dti.x].f.x);
data[dti.x].f.xy = QuadReadAcrossY(data[dti.x].f.xy);
data[dti.x].f.xyz = QuadReadAcrossY(data[dti.x].f.xyz);
data[dti.x].d = QuadReadAcrossY(data[dti.x].d);
data[dti.x].d.x = QuadReadAcrossY(data[dti.x].d.x);
data[dti.x].d.xy = QuadReadAcrossY(data[dti.x].d.xy);
data[dti.x].d.xyz = QuadReadAcrossY(data[dti.x].d.xyz);
data[dti.x].u = QuadReadAcrossDiagonal(data[dti.x].u);
data[dti.x].u.x = QuadReadAcrossDiagonal(data[dti.x].u.x);
data[dti.x].u.xy = QuadReadAcrossDiagonal(data[dti.x].u.xy);
data[dti.x].u.xyz = QuadReadAcrossDiagonal(data[dti.x].u.xyz);
data[dti.x].i = QuadReadAcrossDiagonal(data[dti.x].i);
data[dti.x].i.x = QuadReadAcrossDiagonal(data[dti.x].i.x);
data[dti.x].i.xy = QuadReadAcrossDiagonal(data[dti.x].i.xy);
data[dti.x].i.xyz = QuadReadAcrossDiagonal(data[dti.x].i.xyz);
data[dti.x].f = QuadReadAcrossDiagonal(data[dti.x].f);
data[dti.x].f.x = QuadReadAcrossDiagonal(data[dti.x].f.x);
data[dti.x].f.xy = QuadReadAcrossDiagonal(data[dti.x].f.xy);
data[dti.x].f.xyz = QuadReadAcrossDiagonal(data[dti.x].f.xyz);
data[dti.x].d = QuadReadAcrossDiagonal(data[dti.x].d);
data[dti.x].d.x = QuadReadAcrossDiagonal(data[dti.x].d.x);
data[dti.x].d.xy = QuadReadAcrossDiagonal(data[dti.x].d.xy);
data[dti.x].d.xyz = QuadReadAcrossDiagonal(data[dti.x].d.xyz);
}
......@@ -3,5 +3,5 @@ RWStructuredBuffer<uint> data;
[numthreads(32, 16, 1)]
void CSMain()
{
data[WaveGetLaneIndex()] = (WaveOnce()) ? WaveGetLaneCount() : 0;
data[WaveGetLaneIndex()] = (WaveIsFirstLane()) ? WaveGetLaneCount() : 0;
}
float4 PixelShaderFunction() : COLOR0
{
if (WaveIsHelperLane())
if (WaveIsFirstLane())
{
return float4(1, 2, 3, 4);
}
......
......@@ -11,113 +11,115 @@ RWStructuredBuffer<Types> data;
[numthreads(32, 16, 1)]
void CSMain(uint3 dti : SV_DispatchThreadID)
{
data[dti.x].u = WaveAllSum(data[dti.x].u);
data[dti.x].u.x = WaveAllSum(data[dti.x].u.x);
data[dti.x].u.xy = WaveAllSum(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveAllSum(data[dti.x].u.xyz);
data[dti.x].i = WaveAllSum(data[dti.x].i);
data[dti.x].i.x = WaveAllSum(data[dti.x].i.x);
data[dti.x].i.xy = WaveAllSum(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveAllSum(data[dti.x].i.xyz);
data[dti.x].f = WaveAllSum(data[dti.x].f);
data[dti.x].f.x = WaveAllSum(data[dti.x].f.x);
data[dti.x].f.xy = WaveAllSum(data[dti.x].f.xy);
data[dti.x].f.xyz = WaveAllSum(data[dti.x].f.xyz);
data[dti.x].d = WaveAllSum(data[dti.x].d);
data[dti.x].d.x = WaveAllSum(data[dti.x].d.x);
data[dti.x].d.xy = WaveAllSum(data[dti.x].d.xy);
data[dti.x].d.xyz = WaveAllSum(data[dti.x].d.xyz);
data[dti.x].u = WaveAllProduct(data[dti.x].u);
data[dti.x].u.x = WaveAllProduct(data[dti.x].u.x);
data[dti.x].u.xy = WaveAllProduct(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveAllProduct(data[dti.x].u.xyz);
data[dti.x].i = WaveAllProduct(data[dti.x].i);
data[dti.x].i.x = WaveAllProduct(data[dti.x].i.x);
data[dti.x].i.xy = WaveAllProduct(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveAllProduct(data[dti.x].i.xyz);
data[dti.x].f = WaveAllProduct(data[dti.x].f);
data[dti.x].f.x = WaveAllProduct(data[dti.x].f.x);
data[dti.x].f.xy = WaveAllProduct(data[dti.x].f.xy);
data[dti.x].f.xyz = WaveAllProduct(data[dti.x].f.xyz);
data[dti.x].d = WaveAllProduct(data[dti.x].d);
data[dti.x].d.x = WaveAllProduct(data[dti.x].d.x);
data[dti.x].d.xy = WaveAllProduct(data[dti.x].d.xy);
data[dti.x].d.xyz = WaveAllProduct(data[dti.x].d.xyz);
data[dti.x].u = WaveAllMin(data[dti.x].u);
data[dti.x].u.x = WaveAllMin(data[dti.x].u.x);
data[dti.x].u.xy = WaveAllMin(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveAllMin(data[dti.x].u.xyz);
data[dti.x].i = WaveAllMin(data[dti.x].i);
data[dti.x].i.x = WaveAllMin(data[dti.x].i.x);
data[dti.x].i.xy = WaveAllMin(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveAllMin(data[dti.x].i.xyz);
data[dti.x].f = WaveAllMin(data[dti.x].f);
data[dti.x].f.x = WaveAllMin(data[dti.x].f.x);
data[dti.x].f.xy = WaveAllMin(data[dti.x].f.xy);
data[dti.x].f.xyz = WaveAllMin(data[dti.x].f.xyz);
data[dti.x].d = WaveAllMin(data[dti.x].d);
data[dti.x].d.x = WaveAllMin(data[dti.x].d.x);
data[dti.x].d.xy = WaveAllMin(data[dti.x].d.xy);
data[dti.x].d.xyz = WaveAllMin(data[dti.x].d.xyz);
data[dti.x].u = WaveAllMax(data[dti.x].u);
data[dti.x].u.x = WaveAllMax(data[dti.x].u.x);
data[dti.x].u.xy = WaveAllMax(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveAllMax(data[dti.x].u.xyz);
data[dti.x].i = WaveAllMax(data[dti.x].i);
data[dti.x].i.x = WaveAllMax(data[dti.x].i.x);
data[dti.x].i.xy = WaveAllMax(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveAllMax(data[dti.x].i.xyz);
data[dti.x].f = WaveAllMax(data[dti.x].f);
data[dti.x].f.x = WaveAllMax(data[dti.x].f.x);
data[dti.x].f.xy = WaveAllMax(data[dti.x].f.xy);
data[dti.x].f.xyz = WaveAllMax(data[dti.x].f.xyz);
data[dti.x].d = WaveAllMax(data[dti.x].d);
data[dti.x].d.x = WaveAllMax(data[dti.x].d.x);
data[dti.x].d.xy = WaveAllMax(data[dti.x].d.xy);
data[dti.x].d.xyz = WaveAllMax(data[dti.x].d.xyz);
data[dti.x].u = WaveAllBitAnd(data[dti.x].u);
data[dti.x].u.x = WaveAllBitAnd(data[dti.x].u.x);
data[dti.x].u.xy = WaveAllBitAnd(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveAllBitAnd(data[dti.x].u.xyz);
data[dti.x].i = WaveAllBitAnd(data[dti.x].i);
data[dti.x].i.x = WaveAllBitAnd(data[dti.x].i.x);
data[dti.x].i.xy = WaveAllBitAnd(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveAllBitAnd(data[dti.x].i.xyz);
data[dti.x].u = WaveAllBitOr(data[dti.x].u);
data[dti.x].u.x = WaveAllBitOr(data[dti.x].u.x);
data[dti.x].u.xy = WaveAllBitOr(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveAllBitOr(data[dti.x].u.xyz);
data[dti.x].i = WaveAllBitOr(data[dti.x].i);
data[dti.x].i.x = WaveAllBitOr(data[dti.x].i.x);
data[dti.x].i.xy = WaveAllBitOr(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveAllBitOr(data[dti.x].i.xyz);
data[dti.x].u = WaveAllBitXor(data[dti.x].u);
data[dti.x].u.x = WaveAllBitXor(data[dti.x].u.x);
data[dti.x].u.xy = WaveAllBitXor(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveAllBitXor(data[dti.x].u.xyz);
data[dti.x].i = WaveAllBitXor(data[dti.x].i);
data[dti.x].i.x = WaveAllBitXor(data[dti.x].i.x);
data[dti.x].i.xy = WaveAllBitXor(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveAllBitXor(data[dti.x].i.xyz);
data[dti.x].u = WaveActiveSum(data[dti.x].u);
data[dti.x].u.x = WaveActiveSum(data[dti.x].u.x);
data[dti.x].u.xy = WaveActiveSum(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveActiveSum(data[dti.x].u.xyz);
data[dti.x].i = WaveActiveSum(data[dti.x].i);
data[dti.x].i.x = WaveActiveSum(data[dti.x].i.x);
data[dti.x].i.xy = WaveActiveSum(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveActiveSum(data[dti.x].i.xyz);
data[dti.x].f = WaveActiveSum(data[dti.x].f);
data[dti.x].f.x = WaveActiveSum(data[dti.x].f.x);
data[dti.x].f.xy = WaveActiveSum(data[dti.x].f.xy);
data[dti.x].f.xyz = WaveActiveSum(data[dti.x].f.xyz);
data[dti.x].d = WaveActiveSum(data[dti.x].d);
data[dti.x].d.x = WaveActiveSum(data[dti.x].d.x);
data[dti.x].d.xy = WaveActiveSum(data[dti.x].d.xy);
data[dti.x].d.xyz = WaveActiveSum(data[dti.x].d.xyz);
data[dti.x].u = WaveActiveProduct(data[dti.x].u);
data[dti.x].u.x = WaveActiveProduct(data[dti.x].u.x);
data[dti.x].u.xy = WaveActiveProduct(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveActiveProduct(data[dti.x].u.xyz);
data[dti.x].i = WaveActiveProduct(data[dti.x].i);
data[dti.x].i.x = WaveActiveProduct(data[dti.x].i.x);
data[dti.x].i.xy = WaveActiveProduct(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveActiveProduct(data[dti.x].i.xyz);
data[dti.x].f = WaveActiveProduct(data[dti.x].f);
data[dti.x].f.x = WaveActiveProduct(data[dti.x].f.x);
data[dti.x].f.xy = WaveActiveProduct(data[dti.x].f.xy);
data[dti.x].f.xyz = WaveActiveProduct(data[dti.x].f.xyz);
data[dti.x].d = WaveActiveProduct(data[dti.x].d);
data[dti.x].d.x = WaveActiveProduct(data[dti.x].d.x);
data[dti.x].d.xy = WaveActiveProduct(data[dti.x].d.xy);
data[dti.x].d.xyz = WaveActiveProduct(data[dti.x].d.xyz);
data[dti.x].u = WaveActiveMin(data[dti.x].u);
data[dti.x].u.x = WaveActiveMin(data[dti.x].u.x);
data[dti.x].u.xy = WaveActiveMin(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveActiveMin(data[dti.x].u.xyz);
data[dti.x].i = WaveActiveMin(data[dti.x].i);
data[dti.x].i.x = WaveActiveMin(data[dti.x].i.x);
data[dti.x].i.xy = WaveActiveMin(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveActiveMin(data[dti.x].i.xyz);
data[dti.x].f = WaveActiveMin(data[dti.x].f);
data[dti.x].f.x = WaveActiveMin(data[dti.x].f.x);
data[dti.x].f.xy = WaveActiveMin(data[dti.x].f.xy);
data[dti.x].f.xyz = WaveActiveMin(data[dti.x].f.xyz);
data[dti.x].d = WaveActiveMin(data[dti.x].d);
data[dti.x].d.x = WaveActiveMin(data[dti.x].d.x);
data[dti.x].d.xy = WaveActiveMin(data[dti.x].d.xy);
data[dti.x].d.xyz = WaveActiveMin(data[dti.x].d.xyz);
data[dti.x].u = WaveActiveMax(data[dti.x].u);
data[dti.x].u.x = WaveActiveMax(data[dti.x].u.x);
data[dti.x].u.xy = WaveActiveMax(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveActiveMax(data[dti.x].u.xyz);
data[dti.x].i = WaveActiveMax(data[dti.x].i);
data[dti.x].i.x = WaveActiveMax(data[dti.x].i.x);
data[dti.x].i.xy = WaveActiveMax(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveActiveMax(data[dti.x].i.xyz);
data[dti.x].f = WaveActiveMax(data[dti.x].f);
data[dti.x].f.x = WaveActiveMax(data[dti.x].f.x);
data[dti.x].f.xy = WaveActiveMax(data[dti.x].f.xy);
data[dti.x].f.xyz = WaveActiveMax(data[dti.x].f.xyz);
data[dti.x].d = WaveActiveMax(data[dti.x].d);
data[dti.x].d.x = WaveActiveMax(data[dti.x].d.x);
data[dti.x].d.xy = WaveActiveMax(data[dti.x].d.xy);
data[dti.x].d.xyz = WaveActiveMax(data[dti.x].d.xyz);
data[dti.x].u = WaveActiveBitAnd(data[dti.x].u);
data[dti.x].u.x = WaveActiveBitAnd(data[dti.x].u.x);
data[dti.x].u.xy = WaveActiveBitAnd(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveActiveBitAnd(data[dti.x].u.xyz);
data[dti.x].i = WaveActiveBitAnd(data[dti.x].i);
data[dti.x].i.x = WaveActiveBitAnd(data[dti.x].i.x);
data[dti.x].i.xy = WaveActiveBitAnd(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveActiveBitAnd(data[dti.x].i.xyz);
data[dti.x].u = WaveActiveBitOr(data[dti.x].u);
data[dti.x].u.x = WaveActiveBitOr(data[dti.x].u.x);
data[dti.x].u.xy = WaveActiveBitOr(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveActiveBitOr(data[dti.x].u.xyz);
data[dti.x].i = WaveActiveBitOr(data[dti.x].i);
data[dti.x].i.x = WaveActiveBitOr(data[dti.x].i.x);
data[dti.x].i.xy = WaveActiveBitOr(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveActiveBitOr(data[dti.x].i.xyz);
data[dti.x].u = WaveActiveBitXor(data[dti.x].u);
data[dti.x].u.x = WaveActiveBitXor(data[dti.x].u.x);
data[dti.x].u.xy = WaveActiveBitXor(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveActiveBitXor(data[dti.x].u.xyz);
data[dti.x].i = WaveActiveBitXor(data[dti.x].i);
data[dti.x].i.x = WaveActiveBitXor(data[dti.x].i.x);
data[dti.x].i.xy = WaveActiveBitXor(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveActiveBitXor(data[dti.x].i.xyz);
data[dti.x].u.x = WaveActiveCountBits(data[dti.x].u.x == 0);
}
......@@ -3,7 +3,8 @@ RWStructuredBuffer<uint64_t> data;
[numthreads(32, 16, 1)]
void CSMain(uint3 dti : SV_DispatchThreadID)
{
data[dti.x] = WaveBallot(WaveAnyTrue(dti.x == 0));
data[dti.y] = WaveBallot(WaveAllTrue(dti.y == 0));
data[dti.z] = WaveBallot(WaveAllEqual(dti.z == 0));
data[dti.x] = WaveActiveBallot(WaveActiveAnyTrue(dti.x == 0));
data[dti.y] = WaveActiveBallot(WaveActiveAllTrue(dti.y == 0));
data[dti.z] = WaveActiveBallot(WaveActiveAllEqualBool(dti.z == 0));
data[dti.z] = WaveActiveBallot(WaveActiveAllEqual(dti.z));
}
......@@ -927,10 +927,8 @@ enum TOperator {
// SM6 wave ops
EOpWaveGetLaneCount, // Will decompose to gl_SubgroupSize.
EOpWaveGetLaneIndex, // Will decompose to gl_SubgroupInvocationID.
EOpWaveIsHelperLane, // Will decompose to gl_HelperInvocation.
EOpWaveBallot, // Will decompose to subgroupBallot.
EOpWaveGetOrderedIndex, // Will decompose to an equation containing gl_SubgroupID.
EOpGlobalOrderedCountIncrement, // Will nice error.
EOpWaveActiveCountBits, // Will decompose to subgroupBallotBitCount(subgroupBallot()).
EOpWavePrefixCountBits, // Will decompose to subgroupBallotInclusiveBitCount(subgroupBallot()).
};
class TIntermTraverser;
......
......@@ -367,17 +367,13 @@ INSTANTIATE_TEST_CASE_P(
{"hlsl.type.identifier.frag", "main"},
{"hlsl.typeGraphCopy.vert", "main"},
{"hlsl.typedef.frag", "PixelShaderFunction"},
{"hlsl.wavequery.comp", "CSMain"},
{"hlsl.wavequery.frag", "PixelShaderFunction"},
{"hlsl.wavevote.comp", "CSMain"},
{"hlsl.wavebroadcast.comp", "CSMain"},
{"hlsl.wavereduction.comp", "CSMain"},
{"hlsl.waveprefix.comp", "CSMain"},
{"hlsl.wavequad.comp", "CSMain"},
{"hlsl.waveordered.comp", "CSMain"},
{"hlsl.waveordered2.comp", "CSMain"},
{"hlsl.waveordered.frag", "PixelShaderFunction"},
{"hlsl.waveordered2.frag", "PixelShaderFunction"},
{"hlsl.wavequery.comp", "CSMain"},
{"hlsl.wavequery.frag", "PixelShaderFunction"},
{"hlsl.wavereduction.comp", "CSMain"},
{"hlsl.wavevote.comp", "CSMain"},
{"hlsl.whileLoop.frag", "PixelShaderFunction"},
{"hlsl.void.frag", "PixelShaderFunction"}
}),
......
......@@ -5090,19 +5090,9 @@ void HlslParseContext::decomposeIntrinsic(const TSourceLoc& loc, TIntermTyped*&
node = lookupBuiltinVariable("@gl_SubgroupInvocationID", EbvSubgroupInvocation2, type);
break;
}
case EOpWaveIsHelperLane:
case EOpWaveActiveCountBits:
{
// Mapped to gl_HelperInvocation builtin (We prepend @ to the symbol
// so that it inhabits the symbol table, but has a user-invalid name
// in case some source HLSL defined the symbol also).
TType type(EbtBool, EvqVaryingIn);
node = lookupBuiltinVariable("@gl_HelperInvocation", EbvHelperInvocation, type);
break;
}
case EOpWaveBallot:
{
// Mapped to subgroupBallot() builtin (NOTE: if an IHV has
// a subgroup size > 64 these wave ops will not work for them!)
// Mapped to subgroupBallotBitCount(subgroupBallot()) builtin
// uvec4 type.
TType uvec4Type(EbtUint, EvqTemporary, 4);
......@@ -5111,63 +5101,34 @@ void HlslParseContext::decomposeIntrinsic(const TSourceLoc& loc, TIntermTyped*&
TIntermTyped* res = intermediate.addBuiltInFunctionCall(loc,
EOpSubgroupBallot, true, arguments, uvec4Type);
// And extract a uvec2 holding the first two (x and y) components.
TIntermTyped* xy = handleDotDereference(loc, res, "xy");
// uint type.
TType uintType(EbtUint, EvqTemporary);
// uint64_t type.
TType uint64Type(EbtUint64, EvqTemporary);
// And bitcast the result for a uint64_t
node = intermediate.addBuiltInFunctionCall(loc,
EOpPackUint2x32, true, xy, uint64Type);
EOpSubgroupBallotBitCount, true, res, uintType);
break;
}
case EOpWaveGetOrderedIndex:
case EOpWavePrefixCountBits:
{
if (language == EShLangFragment) {
// NOTE: For HLSL SM6.0 this should work for PS too, but the current GLSL extensions don't allow this.
error(loc, "WaveGetOrderedIndex() unsupported in a pixel/fragment shader", "WaveGetOrderedIndex", "");
break;
}
TType uintType(EbtUint, EvqVaryingIn);
TIntermTyped* subgroupID = lookupBuiltinVariable("@gl_SubgroupID", EbvSubgroupID, uintType);
TIntermTyped* numSubgroups = lookupBuiltinVariable("@gl_NumSubgroups", EbvNumSubgroups, uintType);
TType uvec3Type(EbtUint, EvqVaryingIn, 3);
TIntermTyped* numWorkGroups = lookupBuiltinVariable("@gl_NumWorkGroups", EbvNumWorkGroups, uvec3Type);
TIntermTyped* workGroupID = lookupBuiltinVariable("@gl_WorkGroupID", EbvWorkGroupId, uvec3Type);
// Mapped to subgroupBallotInclusiveBitCount(subgroupBallot())
// builtin
//x & y components of gl_NumWorkGroups
TIntermTyped* numWorkGroupsX = handleDotDereference(loc, numWorkGroups, "x");
TIntermTyped* numWorkGroupsY = handleDotDereference(loc, numWorkGroups, "y");
// uvec4 type.
TType uvec4Type(EbtUint, EvqTemporary, 4);
// x & y components of globalSize
TIntermTyped* globalSizeX = handleBinaryMath(loc, "mul", EOpMul, numSubgroups, numWorkGroupsX);
TIntermTyped* globalSizeY = numWorkGroupsY;
// Get the uvec4 return from subgroupBallot().
TIntermTyped* res = intermediate.addBuiltInFunctionCall(loc,
EOpSubgroupBallot, true, arguments, uvec4Type);
// x, y & z components of gl_WorkGroupID
TIntermTyped* workGroupX = handleDotDereference(loc, workGroupID, "x");
TIntermTyped* workGroupY = handleDotDereference(loc, workGroupID, "y");
TIntermTyped* workGroupZ = handleDotDereference(loc, workGroupID, "z");
// uint type.
TType uintType(EbtUint, EvqTemporary);
// We're going to build up the following variables to get a uniquely ordered ID:
// (globalSize.y * gl_WorkGroupID.z + gl_WorkGroupID.y) * globalSize.x + gl_WorkGroupID.x + gl_SubgroupID
node = handleBinaryMath(loc, "mul", EOpMul, globalSizeY, workGroupZ);
node = handleBinaryMath(loc, "add", EOpAdd, node, workGroupY);
node = handleBinaryMath(loc, "mul", EOpMul, node, globalSizeX);
node = handleBinaryMath(loc, "add", EOpAdd, node, workGroupX);
node = handleBinaryMath(loc, "add", EOpAdd, node, subgroupID);
node = intermediate.addBuiltInFunctionCall(loc,
EOpSubgroupBallotInclusiveBitCount, true, res, uintType);
break;
}
case EOpGlobalOrderedCountIncrement:
{
// NOTE: For HLSL SM6.0 this should work, but the current GLSL extensions don't allow this.
error(loc, "GlobalOrderedCountIncrement() unsupported", "GlobalOrderedCountIncrement", "");
break;
}
default:
break; // most pass through unchanged
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment