Commit ce443b3a by Neil Henning

Add support for the latest shader model 6.0 wave specs.

MSVC annoying totally redid their wave ops here: https://github.com/Microsoft/DirectXShaderCompiler/wiki/Wave-Intrinsics I've had to totally rework the existing wave ops to support this. ``` [----------] Global test environment tear-down [==========] 1142 tests from 44 test cases ran. (186552 ms total) [ PASSED ] 1142 tests. ```
parent d487d4d0
...@@ -619,6 +619,7 @@ spv::BuiltIn TGlslangToSpvTraverser::TranslateBuiltInDecoration(glslang::TBuiltI ...@@ -619,6 +619,7 @@ spv::BuiltIn TGlslangToSpvTraverser::TranslateBuiltInDecoration(glslang::TBuiltI
builder.addCapability(spv::CapabilityGroupNonUniform); builder.addCapability(spv::CapabilityGroupNonUniform);
builder.addCapability(spv::CapabilityGroupNonUniformBallot); builder.addCapability(spv::CapabilityGroupNonUniformBallot);
return spv::BuiltInSubgroupLtMask; return spv::BuiltInSubgroupLtMask;
#ifdef AMD_EXTENSIONS #ifdef AMD_EXTENSIONS
case glslang::EbvBaryCoordNoPersp: case glslang::EbvBaryCoordNoPersp:
builder.addExtension(spv::E_SPV_AMD_shader_explicit_vertex_parameter); builder.addExtension(spv::E_SPV_AMD_shader_explicit_vertex_parameter);
......
File mode changed from 100755 to 100644
...@@ -7,7 +7,7 @@ gl_FragCoord origin is upper left ...@@ -7,7 +7,7 @@ gl_FragCoord origin is upper left
0:? Sequence 0:? Sequence
0:3 Test condition and select ( temp void) 0:3 Test condition and select ( temp void)
0:3 Condition 0:3 Condition
0:3 '@gl_HelperInvocation' ( in bool HelperInvocation) 0:3 subgroupElect ( temp bool)
0:3 true case 0:3 true case
0:? Sequence 0:? Sequence
0:5 Branch: Return with expression 0:5 Branch: Return with expression
...@@ -45,7 +45,7 @@ gl_FragCoord origin is upper left ...@@ -45,7 +45,7 @@ gl_FragCoord origin is upper left
0:? Sequence 0:? Sequence
0:3 Test condition and select ( temp void) 0:3 Test condition and select ( temp void)
0:3 Condition 0:3 Condition
0:3 '@gl_HelperInvocation' ( in bool HelperInvocation) 0:3 subgroupElect ( temp bool)
0:3 true case 0:3 true case
0:? Sequence 0:? Sequence
0:5 Branch: Return with expression 0:5 Branch: Return with expression
...@@ -76,16 +76,15 @@ gl_FragCoord origin is upper left ...@@ -76,16 +76,15 @@ gl_FragCoord origin is upper left
// Id's are bound by 30 // Id's are bound by 30
Capability Shader Capability Shader
Capability GroupNonUniform
1: ExtInstImport "GLSL.std.450" 1: ExtInstImport "GLSL.std.450"
MemoryModel Logical GLSL450 MemoryModel Logical GLSL450
EntryPoint Fragment 4 "PixelShaderFunction" 13 28 EntryPoint Fragment 4 "PixelShaderFunction" 28
ExecutionMode 4 OriginUpperLeft ExecutionMode 4 OriginUpperLeft
Source HLSL 500 Source HLSL 500
Name 4 "PixelShaderFunction" Name 4 "PixelShaderFunction"
Name 9 "@PixelShaderFunction(" Name 9 "@PixelShaderFunction("
Name 13 "@gl_HelperInvocation"
Name 28 "@entryPointOutput" Name 28 "@entryPointOutput"
Decorate 13(@gl_HelperInvocation) BuiltIn HelperInvocation
Decorate 28(@entryPointOutput) Location 0 Decorate 28(@entryPointOutput) Location 0
2: TypeVoid 2: TypeVoid
3: TypeFunction 2 3: TypeFunction 2
...@@ -93,8 +92,8 @@ gl_FragCoord origin is upper left ...@@ -93,8 +92,8 @@ gl_FragCoord origin is upper left
7: TypeVector 6(float) 4 7: TypeVector 6(float) 4
8: TypeFunction 7(fvec4) 8: TypeFunction 7(fvec4)
11: TypeBool 11: TypeBool
12: TypePointer Input 11(bool) 12: TypeInt 32 0
13(@gl_HelperInvocation): 12(ptr) Variable Input 13: 12(int) Constant 3
17: 6(float) Constant 1065353216 17: 6(float) Constant 1065353216
18: 6(float) Constant 1073741824 18: 6(float) Constant 1073741824
19: 6(float) Constant 1077936128 19: 6(float) Constant 1077936128
...@@ -111,7 +110,7 @@ gl_FragCoord origin is upper left ...@@ -111,7 +110,7 @@ gl_FragCoord origin is upper left
FunctionEnd FunctionEnd
9(@PixelShaderFunction(): 7(fvec4) Function None 8 9(@PixelShaderFunction(): 7(fvec4) Function None 8
10: Label 10: Label
14: 11(bool) Load 13(@gl_HelperInvocation) 14: 11(bool) GroupNonUniformElect 13
SelectionMerge 16 None SelectionMerge 16 None
BranchConditional 14 15 23 BranchConditional 14 15 23
15: Label 15: Label
......
RWStructuredBuffer<uint> data;
[numthreads(32, 16, 1)]
void CSMain()
{
data[WaveGetOrderedIndex()] = 1;
}
float4 PixelShaderFunction() : COLOR0
{
if (0 == WaveGetOrderedIndex())
{
return float4(1, 2, 3, 4);
}
else
{
return float4(4, 3, 2, 1);
}
}
RWStructuredBuffer<uint> data;
[numthreads(32, 16, 1)]
void CSMain()
{
uint i = 42;
data[GlobalOrderedCountIncrement(i)] = 1;
}
float4 PixelShaderFunction() : COLOR0
{
uint i = 42;
if (0 == GlobalOrderedCountIncrement(i))
{
return float4(1, 2, 3, 4);
}
else
{
return float4(4, 3, 2, 1);
}
}
...@@ -50,4 +50,6 @@ void CSMain(uint3 dti : SV_DispatchThreadID) ...@@ -50,4 +50,6 @@ void CSMain(uint3 dti : SV_DispatchThreadID)
data[dti.x].d.x = WavePrefixProduct(data[dti.x].d.x); data[dti.x].d.x = WavePrefixProduct(data[dti.x].d.x);
data[dti.x].d.xy = WavePrefixProduct(data[dti.x].d.xy); data[dti.x].d.xy = WavePrefixProduct(data[dti.x].d.xy);
data[dti.x].d.xyz = WavePrefixProduct(data[dti.x].d.xyz); data[dti.x].d.xyz = WavePrefixProduct(data[dti.x].d.xyz);
data[dti.x].u.x = WavePrefixCountBits(data[dti.x].u.x == 0);
} }
...@@ -91,43 +91,63 @@ void CSMain(uint3 dti : SV_DispatchThreadID) ...@@ -91,43 +91,63 @@ void CSMain(uint3 dti : SV_DispatchThreadID)
data[dti.x].d.xy = QuadReadLaneAt(data[dti.x].d.xy, 3); data[dti.x].d.xy = QuadReadLaneAt(data[dti.x].d.xy, 3);
data[dti.x].d.xyz = QuadReadLaneAt(data[dti.x].d.xyz, 3); data[dti.x].d.xyz = QuadReadLaneAt(data[dti.x].d.xyz, 3);
data[dti.x].u = QuadSwapX(data[dti.x].u); data[dti.x].u = QuadReadAcrossX(data[dti.x].u);
data[dti.x].u.x = QuadSwapX(data[dti.x].u.x); data[dti.x].u.x = QuadReadAcrossX(data[dti.x].u.x);
data[dti.x].u.xy = QuadSwapX(data[dti.x].u.xy); data[dti.x].u.xy = QuadReadAcrossX(data[dti.x].u.xy);
data[dti.x].u.xyz = QuadSwapX(data[dti.x].u.xyz); data[dti.x].u.xyz = QuadReadAcrossX(data[dti.x].u.xyz);
data[dti.x].i = QuadSwapX(data[dti.x].i); data[dti.x].i = QuadReadAcrossX(data[dti.x].i);
data[dti.x].i.x = QuadSwapX(data[dti.x].i.x); data[dti.x].i.x = QuadReadAcrossX(data[dti.x].i.x);
data[dti.x].i.xy = QuadSwapX(data[dti.x].i.xy); data[dti.x].i.xy = QuadReadAcrossX(data[dti.x].i.xy);
data[dti.x].i.xyz = QuadSwapX(data[dti.x].i.xyz); data[dti.x].i.xyz = QuadReadAcrossX(data[dti.x].i.xyz);
data[dti.x].f = QuadSwapX(data[dti.x].f); data[dti.x].f = QuadReadAcrossX(data[dti.x].f);
data[dti.x].f.x = QuadSwapX(data[dti.x].f.x); data[dti.x].f.x = QuadReadAcrossX(data[dti.x].f.x);
data[dti.x].f.xy = QuadSwapX(data[dti.x].f.xy); data[dti.x].f.xy = QuadReadAcrossX(data[dti.x].f.xy);
data[dti.x].f.xyz = QuadSwapX(data[dti.x].f.xyz); data[dti.x].f.xyz = QuadReadAcrossX(data[dti.x].f.xyz);
data[dti.x].d = QuadSwapX(data[dti.x].d); data[dti.x].d = QuadReadAcrossX(data[dti.x].d);
data[dti.x].d.x = QuadSwapX(data[dti.x].d.x); data[dti.x].d.x = QuadReadAcrossX(data[dti.x].d.x);
data[dti.x].d.xy = QuadSwapX(data[dti.x].d.xy); data[dti.x].d.xy = QuadReadAcrossX(data[dti.x].d.xy);
data[dti.x].d.xyz = QuadSwapX(data[dti.x].d.xyz); data[dti.x].d.xyz = QuadReadAcrossX(data[dti.x].d.xyz);
data[dti.x].u = QuadSwapY(data[dti.x].u); data[dti.x].u = QuadReadAcrossY(data[dti.x].u);
data[dti.x].u.x = QuadSwapY(data[dti.x].u.x); data[dti.x].u.x = QuadReadAcrossY(data[dti.x].u.x);
data[dti.x].u.xy = QuadSwapY(data[dti.x].u.xy); data[dti.x].u.xy = QuadReadAcrossY(data[dti.x].u.xy);
data[dti.x].u.xyz = QuadSwapY(data[dti.x].u.xyz); data[dti.x].u.xyz = QuadReadAcrossY(data[dti.x].u.xyz);
data[dti.x].i = QuadSwapY(data[dti.x].i); data[dti.x].i = QuadReadAcrossY(data[dti.x].i);
data[dti.x].i.x = QuadSwapY(data[dti.x].i.x); data[dti.x].i.x = QuadReadAcrossY(data[dti.x].i.x);
data[dti.x].i.xy = QuadSwapY(data[dti.x].i.xy); data[dti.x].i.xy = QuadReadAcrossY(data[dti.x].i.xy);
data[dti.x].i.xyz = QuadSwapY(data[dti.x].i.xyz); data[dti.x].i.xyz = QuadReadAcrossY(data[dti.x].i.xyz);
data[dti.x].f = QuadSwapY(data[dti.x].f); data[dti.x].f = QuadReadAcrossY(data[dti.x].f);
data[dti.x].f.x = QuadSwapY(data[dti.x].f.x); data[dti.x].f.x = QuadReadAcrossY(data[dti.x].f.x);
data[dti.x].f.xy = QuadSwapY(data[dti.x].f.xy); data[dti.x].f.xy = QuadReadAcrossY(data[dti.x].f.xy);
data[dti.x].f.xyz = QuadSwapY(data[dti.x].f.xyz); data[dti.x].f.xyz = QuadReadAcrossY(data[dti.x].f.xyz);
data[dti.x].d = QuadSwapY(data[dti.x].d); data[dti.x].d = QuadReadAcrossY(data[dti.x].d);
data[dti.x].d.x = QuadSwapY(data[dti.x].d.x); data[dti.x].d.x = QuadReadAcrossY(data[dti.x].d.x);
data[dti.x].d.xy = QuadSwapY(data[dti.x].d.xy); data[dti.x].d.xy = QuadReadAcrossY(data[dti.x].d.xy);
data[dti.x].d.xyz = QuadSwapY(data[dti.x].d.xyz); data[dti.x].d.xyz = QuadReadAcrossY(data[dti.x].d.xyz);
data[dti.x].u = QuadReadAcrossDiagonal(data[dti.x].u);
data[dti.x].u.x = QuadReadAcrossDiagonal(data[dti.x].u.x);
data[dti.x].u.xy = QuadReadAcrossDiagonal(data[dti.x].u.xy);
data[dti.x].u.xyz = QuadReadAcrossDiagonal(data[dti.x].u.xyz);
data[dti.x].i = QuadReadAcrossDiagonal(data[dti.x].i);
data[dti.x].i.x = QuadReadAcrossDiagonal(data[dti.x].i.x);
data[dti.x].i.xy = QuadReadAcrossDiagonal(data[dti.x].i.xy);
data[dti.x].i.xyz = QuadReadAcrossDiagonal(data[dti.x].i.xyz);
data[dti.x].f = QuadReadAcrossDiagonal(data[dti.x].f);
data[dti.x].f.x = QuadReadAcrossDiagonal(data[dti.x].f.x);
data[dti.x].f.xy = QuadReadAcrossDiagonal(data[dti.x].f.xy);
data[dti.x].f.xyz = QuadReadAcrossDiagonal(data[dti.x].f.xyz);
data[dti.x].d = QuadReadAcrossDiagonal(data[dti.x].d);
data[dti.x].d.x = QuadReadAcrossDiagonal(data[dti.x].d.x);
data[dti.x].d.xy = QuadReadAcrossDiagonal(data[dti.x].d.xy);
data[dti.x].d.xyz = QuadReadAcrossDiagonal(data[dti.x].d.xyz);
} }
...@@ -3,5 +3,5 @@ RWStructuredBuffer<uint> data; ...@@ -3,5 +3,5 @@ RWStructuredBuffer<uint> data;
[numthreads(32, 16, 1)] [numthreads(32, 16, 1)]
void CSMain() void CSMain()
{ {
data[WaveGetLaneIndex()] = (WaveOnce()) ? WaveGetLaneCount() : 0; data[WaveGetLaneIndex()] = (WaveIsFirstLane()) ? WaveGetLaneCount() : 0;
} }
float4 PixelShaderFunction() : COLOR0 float4 PixelShaderFunction() : COLOR0
{ {
if (WaveIsHelperLane()) if (WaveIsFirstLane())
{ {
return float4(1, 2, 3, 4); return float4(1, 2, 3, 4);
} }
......
...@@ -11,113 +11,115 @@ RWStructuredBuffer<Types> data; ...@@ -11,113 +11,115 @@ RWStructuredBuffer<Types> data;
[numthreads(32, 16, 1)] [numthreads(32, 16, 1)]
void CSMain(uint3 dti : SV_DispatchThreadID) void CSMain(uint3 dti : SV_DispatchThreadID)
{ {
data[dti.x].u = WaveAllSum(data[dti.x].u); data[dti.x].u = WaveActiveSum(data[dti.x].u);
data[dti.x].u.x = WaveAllSum(data[dti.x].u.x); data[dti.x].u.x = WaveActiveSum(data[dti.x].u.x);
data[dti.x].u.xy = WaveAllSum(data[dti.x].u.xy); data[dti.x].u.xy = WaveActiveSum(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveAllSum(data[dti.x].u.xyz); data[dti.x].u.xyz = WaveActiveSum(data[dti.x].u.xyz);
data[dti.x].i = WaveAllSum(data[dti.x].i); data[dti.x].i = WaveActiveSum(data[dti.x].i);
data[dti.x].i.x = WaveAllSum(data[dti.x].i.x); data[dti.x].i.x = WaveActiveSum(data[dti.x].i.x);
data[dti.x].i.xy = WaveAllSum(data[dti.x].i.xy); data[dti.x].i.xy = WaveActiveSum(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveAllSum(data[dti.x].i.xyz); data[dti.x].i.xyz = WaveActiveSum(data[dti.x].i.xyz);
data[dti.x].f = WaveAllSum(data[dti.x].f); data[dti.x].f = WaveActiveSum(data[dti.x].f);
data[dti.x].f.x = WaveAllSum(data[dti.x].f.x); data[dti.x].f.x = WaveActiveSum(data[dti.x].f.x);
data[dti.x].f.xy = WaveAllSum(data[dti.x].f.xy); data[dti.x].f.xy = WaveActiveSum(data[dti.x].f.xy);
data[dti.x].f.xyz = WaveAllSum(data[dti.x].f.xyz); data[dti.x].f.xyz = WaveActiveSum(data[dti.x].f.xyz);
data[dti.x].d = WaveAllSum(data[dti.x].d); data[dti.x].d = WaveActiveSum(data[dti.x].d);
data[dti.x].d.x = WaveAllSum(data[dti.x].d.x); data[dti.x].d.x = WaveActiveSum(data[dti.x].d.x);
data[dti.x].d.xy = WaveAllSum(data[dti.x].d.xy); data[dti.x].d.xy = WaveActiveSum(data[dti.x].d.xy);
data[dti.x].d.xyz = WaveAllSum(data[dti.x].d.xyz); data[dti.x].d.xyz = WaveActiveSum(data[dti.x].d.xyz);
data[dti.x].u = WaveAllProduct(data[dti.x].u); data[dti.x].u = WaveActiveProduct(data[dti.x].u);
data[dti.x].u.x = WaveAllProduct(data[dti.x].u.x); data[dti.x].u.x = WaveActiveProduct(data[dti.x].u.x);
data[dti.x].u.xy = WaveAllProduct(data[dti.x].u.xy); data[dti.x].u.xy = WaveActiveProduct(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveAllProduct(data[dti.x].u.xyz); data[dti.x].u.xyz = WaveActiveProduct(data[dti.x].u.xyz);
data[dti.x].i = WaveAllProduct(data[dti.x].i); data[dti.x].i = WaveActiveProduct(data[dti.x].i);
data[dti.x].i.x = WaveAllProduct(data[dti.x].i.x); data[dti.x].i.x = WaveActiveProduct(data[dti.x].i.x);
data[dti.x].i.xy = WaveAllProduct(data[dti.x].i.xy); data[dti.x].i.xy = WaveActiveProduct(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveAllProduct(data[dti.x].i.xyz); data[dti.x].i.xyz = WaveActiveProduct(data[dti.x].i.xyz);
data[dti.x].f = WaveAllProduct(data[dti.x].f); data[dti.x].f = WaveActiveProduct(data[dti.x].f);
data[dti.x].f.x = WaveAllProduct(data[dti.x].f.x); data[dti.x].f.x = WaveActiveProduct(data[dti.x].f.x);
data[dti.x].f.xy = WaveAllProduct(data[dti.x].f.xy); data[dti.x].f.xy = WaveActiveProduct(data[dti.x].f.xy);
data[dti.x].f.xyz = WaveAllProduct(data[dti.x].f.xyz); data[dti.x].f.xyz = WaveActiveProduct(data[dti.x].f.xyz);
data[dti.x].d = WaveAllProduct(data[dti.x].d); data[dti.x].d = WaveActiveProduct(data[dti.x].d);
data[dti.x].d.x = WaveAllProduct(data[dti.x].d.x); data[dti.x].d.x = WaveActiveProduct(data[dti.x].d.x);
data[dti.x].d.xy = WaveAllProduct(data[dti.x].d.xy); data[dti.x].d.xy = WaveActiveProduct(data[dti.x].d.xy);
data[dti.x].d.xyz = WaveAllProduct(data[dti.x].d.xyz); data[dti.x].d.xyz = WaveActiveProduct(data[dti.x].d.xyz);
data[dti.x].u = WaveAllMin(data[dti.x].u); data[dti.x].u = WaveActiveMin(data[dti.x].u);
data[dti.x].u.x = WaveAllMin(data[dti.x].u.x); data[dti.x].u.x = WaveActiveMin(data[dti.x].u.x);
data[dti.x].u.xy = WaveAllMin(data[dti.x].u.xy); data[dti.x].u.xy = WaveActiveMin(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveAllMin(data[dti.x].u.xyz); data[dti.x].u.xyz = WaveActiveMin(data[dti.x].u.xyz);
data[dti.x].i = WaveAllMin(data[dti.x].i); data[dti.x].i = WaveActiveMin(data[dti.x].i);
data[dti.x].i.x = WaveAllMin(data[dti.x].i.x); data[dti.x].i.x = WaveActiveMin(data[dti.x].i.x);
data[dti.x].i.xy = WaveAllMin(data[dti.x].i.xy); data[dti.x].i.xy = WaveActiveMin(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveAllMin(data[dti.x].i.xyz); data[dti.x].i.xyz = WaveActiveMin(data[dti.x].i.xyz);
data[dti.x].f = WaveAllMin(data[dti.x].f); data[dti.x].f = WaveActiveMin(data[dti.x].f);
data[dti.x].f.x = WaveAllMin(data[dti.x].f.x); data[dti.x].f.x = WaveActiveMin(data[dti.x].f.x);
data[dti.x].f.xy = WaveAllMin(data[dti.x].f.xy); data[dti.x].f.xy = WaveActiveMin(data[dti.x].f.xy);
data[dti.x].f.xyz = WaveAllMin(data[dti.x].f.xyz); data[dti.x].f.xyz = WaveActiveMin(data[dti.x].f.xyz);
data[dti.x].d = WaveAllMin(data[dti.x].d); data[dti.x].d = WaveActiveMin(data[dti.x].d);
data[dti.x].d.x = WaveAllMin(data[dti.x].d.x); data[dti.x].d.x = WaveActiveMin(data[dti.x].d.x);
data[dti.x].d.xy = WaveAllMin(data[dti.x].d.xy); data[dti.x].d.xy = WaveActiveMin(data[dti.x].d.xy);
data[dti.x].d.xyz = WaveAllMin(data[dti.x].d.xyz); data[dti.x].d.xyz = WaveActiveMin(data[dti.x].d.xyz);
data[dti.x].u = WaveAllMax(data[dti.x].u); data[dti.x].u = WaveActiveMax(data[dti.x].u);
data[dti.x].u.x = WaveAllMax(data[dti.x].u.x); data[dti.x].u.x = WaveActiveMax(data[dti.x].u.x);
data[dti.x].u.xy = WaveAllMax(data[dti.x].u.xy); data[dti.x].u.xy = WaveActiveMax(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveAllMax(data[dti.x].u.xyz); data[dti.x].u.xyz = WaveActiveMax(data[dti.x].u.xyz);
data[dti.x].i = WaveAllMax(data[dti.x].i); data[dti.x].i = WaveActiveMax(data[dti.x].i);
data[dti.x].i.x = WaveAllMax(data[dti.x].i.x); data[dti.x].i.x = WaveActiveMax(data[dti.x].i.x);
data[dti.x].i.xy = WaveAllMax(data[dti.x].i.xy); data[dti.x].i.xy = WaveActiveMax(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveAllMax(data[dti.x].i.xyz); data[dti.x].i.xyz = WaveActiveMax(data[dti.x].i.xyz);
data[dti.x].f = WaveAllMax(data[dti.x].f); data[dti.x].f = WaveActiveMax(data[dti.x].f);
data[dti.x].f.x = WaveAllMax(data[dti.x].f.x); data[dti.x].f.x = WaveActiveMax(data[dti.x].f.x);
data[dti.x].f.xy = WaveAllMax(data[dti.x].f.xy); data[dti.x].f.xy = WaveActiveMax(data[dti.x].f.xy);
data[dti.x].f.xyz = WaveAllMax(data[dti.x].f.xyz); data[dti.x].f.xyz = WaveActiveMax(data[dti.x].f.xyz);
data[dti.x].d = WaveAllMax(data[dti.x].d); data[dti.x].d = WaveActiveMax(data[dti.x].d);
data[dti.x].d.x = WaveAllMax(data[dti.x].d.x); data[dti.x].d.x = WaveActiveMax(data[dti.x].d.x);
data[dti.x].d.xy = WaveAllMax(data[dti.x].d.xy); data[dti.x].d.xy = WaveActiveMax(data[dti.x].d.xy);
data[dti.x].d.xyz = WaveAllMax(data[dti.x].d.xyz); data[dti.x].d.xyz = WaveActiveMax(data[dti.x].d.xyz);
data[dti.x].u = WaveAllBitAnd(data[dti.x].u); data[dti.x].u = WaveActiveBitAnd(data[dti.x].u);
data[dti.x].u.x = WaveAllBitAnd(data[dti.x].u.x); data[dti.x].u.x = WaveActiveBitAnd(data[dti.x].u.x);
data[dti.x].u.xy = WaveAllBitAnd(data[dti.x].u.xy); data[dti.x].u.xy = WaveActiveBitAnd(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveAllBitAnd(data[dti.x].u.xyz); data[dti.x].u.xyz = WaveActiveBitAnd(data[dti.x].u.xyz);
data[dti.x].i = WaveAllBitAnd(data[dti.x].i); data[dti.x].i = WaveActiveBitAnd(data[dti.x].i);
data[dti.x].i.x = WaveAllBitAnd(data[dti.x].i.x); data[dti.x].i.x = WaveActiveBitAnd(data[dti.x].i.x);
data[dti.x].i.xy = WaveAllBitAnd(data[dti.x].i.xy); data[dti.x].i.xy = WaveActiveBitAnd(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveAllBitAnd(data[dti.x].i.xyz); data[dti.x].i.xyz = WaveActiveBitAnd(data[dti.x].i.xyz);
data[dti.x].u = WaveAllBitOr(data[dti.x].u); data[dti.x].u = WaveActiveBitOr(data[dti.x].u);
data[dti.x].u.x = WaveAllBitOr(data[dti.x].u.x); data[dti.x].u.x = WaveActiveBitOr(data[dti.x].u.x);
data[dti.x].u.xy = WaveAllBitOr(data[dti.x].u.xy); data[dti.x].u.xy = WaveActiveBitOr(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveAllBitOr(data[dti.x].u.xyz); data[dti.x].u.xyz = WaveActiveBitOr(data[dti.x].u.xyz);
data[dti.x].i = WaveAllBitOr(data[dti.x].i); data[dti.x].i = WaveActiveBitOr(data[dti.x].i);
data[dti.x].i.x = WaveAllBitOr(data[dti.x].i.x); data[dti.x].i.x = WaveActiveBitOr(data[dti.x].i.x);
data[dti.x].i.xy = WaveAllBitOr(data[dti.x].i.xy); data[dti.x].i.xy = WaveActiveBitOr(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveAllBitOr(data[dti.x].i.xyz); data[dti.x].i.xyz = WaveActiveBitOr(data[dti.x].i.xyz);
data[dti.x].u = WaveAllBitXor(data[dti.x].u); data[dti.x].u = WaveActiveBitXor(data[dti.x].u);
data[dti.x].u.x = WaveAllBitXor(data[dti.x].u.x); data[dti.x].u.x = WaveActiveBitXor(data[dti.x].u.x);
data[dti.x].u.xy = WaveAllBitXor(data[dti.x].u.xy); data[dti.x].u.xy = WaveActiveBitXor(data[dti.x].u.xy);
data[dti.x].u.xyz = WaveAllBitXor(data[dti.x].u.xyz); data[dti.x].u.xyz = WaveActiveBitXor(data[dti.x].u.xyz);
data[dti.x].i = WaveAllBitXor(data[dti.x].i); data[dti.x].i = WaveActiveBitXor(data[dti.x].i);
data[dti.x].i.x = WaveAllBitXor(data[dti.x].i.x); data[dti.x].i.x = WaveActiveBitXor(data[dti.x].i.x);
data[dti.x].i.xy = WaveAllBitXor(data[dti.x].i.xy); data[dti.x].i.xy = WaveActiveBitXor(data[dti.x].i.xy);
data[dti.x].i.xyz = WaveAllBitXor(data[dti.x].i.xyz); data[dti.x].i.xyz = WaveActiveBitXor(data[dti.x].i.xyz);
data[dti.x].u.x = WaveActiveCountBits(data[dti.x].u.x == 0);
} }
...@@ -3,7 +3,8 @@ RWStructuredBuffer<uint64_t> data; ...@@ -3,7 +3,8 @@ RWStructuredBuffer<uint64_t> data;
[numthreads(32, 16, 1)] [numthreads(32, 16, 1)]
void CSMain(uint3 dti : SV_DispatchThreadID) void CSMain(uint3 dti : SV_DispatchThreadID)
{ {
data[dti.x] = WaveBallot(WaveAnyTrue(dti.x == 0)); data[dti.x] = WaveActiveBallot(WaveActiveAnyTrue(dti.x == 0));
data[dti.y] = WaveBallot(WaveAllTrue(dti.y == 0)); data[dti.y] = WaveActiveBallot(WaveActiveAllTrue(dti.y == 0));
data[dti.z] = WaveBallot(WaveAllEqual(dti.z == 0)); data[dti.z] = WaveActiveBallot(WaveActiveAllEqualBool(dti.z == 0));
data[dti.z] = WaveActiveBallot(WaveActiveAllEqual(dti.z));
} }
...@@ -927,10 +927,8 @@ enum TOperator { ...@@ -927,10 +927,8 @@ enum TOperator {
// SM6 wave ops // SM6 wave ops
EOpWaveGetLaneCount, // Will decompose to gl_SubgroupSize. EOpWaveGetLaneCount, // Will decompose to gl_SubgroupSize.
EOpWaveGetLaneIndex, // Will decompose to gl_SubgroupInvocationID. EOpWaveGetLaneIndex, // Will decompose to gl_SubgroupInvocationID.
EOpWaveIsHelperLane, // Will decompose to gl_HelperInvocation. EOpWaveActiveCountBits, // Will decompose to subgroupBallotBitCount(subgroupBallot()).
EOpWaveBallot, // Will decompose to subgroupBallot. EOpWavePrefixCountBits, // Will decompose to subgroupBallotInclusiveBitCount(subgroupBallot()).
EOpWaveGetOrderedIndex, // Will decompose to an equation containing gl_SubgroupID.
EOpGlobalOrderedCountIncrement, // Will nice error.
}; };
class TIntermTraverser; class TIntermTraverser;
......
...@@ -367,17 +367,13 @@ INSTANTIATE_TEST_CASE_P( ...@@ -367,17 +367,13 @@ INSTANTIATE_TEST_CASE_P(
{"hlsl.type.identifier.frag", "main"}, {"hlsl.type.identifier.frag", "main"},
{"hlsl.typeGraphCopy.vert", "main"}, {"hlsl.typeGraphCopy.vert", "main"},
{"hlsl.typedef.frag", "PixelShaderFunction"}, {"hlsl.typedef.frag", "PixelShaderFunction"},
{"hlsl.wavequery.comp", "CSMain"},
{"hlsl.wavequery.frag", "PixelShaderFunction"},
{"hlsl.wavevote.comp", "CSMain"},
{"hlsl.wavebroadcast.comp", "CSMain"}, {"hlsl.wavebroadcast.comp", "CSMain"},
{"hlsl.wavereduction.comp", "CSMain"},
{"hlsl.waveprefix.comp", "CSMain"}, {"hlsl.waveprefix.comp", "CSMain"},
{"hlsl.wavequad.comp", "CSMain"}, {"hlsl.wavequad.comp", "CSMain"},
{"hlsl.waveordered.comp", "CSMain"}, {"hlsl.wavequery.comp", "CSMain"},
{"hlsl.waveordered2.comp", "CSMain"}, {"hlsl.wavequery.frag", "PixelShaderFunction"},
{"hlsl.waveordered.frag", "PixelShaderFunction"}, {"hlsl.wavereduction.comp", "CSMain"},
{"hlsl.waveordered2.frag", "PixelShaderFunction"}, {"hlsl.wavevote.comp", "CSMain"},
{"hlsl.whileLoop.frag", "PixelShaderFunction"}, {"hlsl.whileLoop.frag", "PixelShaderFunction"},
{"hlsl.void.frag", "PixelShaderFunction"} {"hlsl.void.frag", "PixelShaderFunction"}
}), }),
......
...@@ -5090,19 +5090,9 @@ void HlslParseContext::decomposeIntrinsic(const TSourceLoc& loc, TIntermTyped*& ...@@ -5090,19 +5090,9 @@ void HlslParseContext::decomposeIntrinsic(const TSourceLoc& loc, TIntermTyped*&
node = lookupBuiltinVariable("@gl_SubgroupInvocationID", EbvSubgroupInvocation2, type); node = lookupBuiltinVariable("@gl_SubgroupInvocationID", EbvSubgroupInvocation2, type);
break; break;
} }
case EOpWaveIsHelperLane: case EOpWaveActiveCountBits:
{ {
// Mapped to gl_HelperInvocation builtin (We preprend @ to the symbol // Mapped to subgroupBallotBitCount(subgroupBallot()) builtin
// so that it inhabits the symbol table, but has a user-invalid name
// in-case some source HLSL defined the symbol also).
TType type(EbtBool, EvqVaryingIn);
node = lookupBuiltinVariable("@gl_HelperInvocation", EbvHelperInvocation, type);
break;
}
case EOpWaveBallot:
{
// Mapped to subgroupBallot() builtin (NOTE: if an IHV has
// a subgroup size > 64 these wave ops will not work for them!)
// uvec4 type. // uvec4 type.
TType uvec4Type(EbtUint, EvqTemporary, 4); TType uvec4Type(EbtUint, EvqTemporary, 4);
...@@ -5111,63 +5101,34 @@ void HlslParseContext::decomposeIntrinsic(const TSourceLoc& loc, TIntermTyped*& ...@@ -5111,63 +5101,34 @@ void HlslParseContext::decomposeIntrinsic(const TSourceLoc& loc, TIntermTyped*&
TIntermTyped* res = intermediate.addBuiltInFunctionCall(loc, TIntermTyped* res = intermediate.addBuiltInFunctionCall(loc,
EOpSubgroupBallot, true, arguments, uvec4Type); EOpSubgroupBallot, true, arguments, uvec4Type);
// And extract a uvec2 for the two highest components. // uint type.
TIntermTyped* xy = handleDotDereference(loc, res, "xy"); TType uintType(EbtUint, EvqTemporary);
// uint64_t type.
TType uint64Type(EbtUint64, EvqTemporary);
// And bitcast the result for a uint64_t
node = intermediate.addBuiltInFunctionCall(loc, node = intermediate.addBuiltInFunctionCall(loc,
EOpPackUint2x32, true, xy, uint64Type); EOpSubgroupBallotBitCount, true, res, uintType);
break; break;
} }
case EOpWaveGetOrderedIndex: case EOpWavePrefixCountBits:
{ {
if (language == EShLangFragment) { // Mapped to subgroupBallotInclusiveBitCount(subgroupBallot())
// NOTE: For HLSL SM6.0 this should work for PS too, but the current GLSL extensions don't allow this. // builtin
error(loc, "WaveGetOrderedIndex() unsupported in a pixel/fragment shader", "WaveGetOrderedIndex", "");
break;
}
TType uintType(EbtUint, EvqVaryingIn);
TIntermTyped* subgroupID = lookupBuiltinVariable("@gl_SubgroupID", EbvSubgroupID, uintType);
TIntermTyped* numSubgroups = lookupBuiltinVariable("@gl_NumSubgroups", EbvNumSubgroups, uintType);
TType uvec3Type(EbtUint, EvqVaryingIn, 3);
TIntermTyped* numWorkGroups = lookupBuiltinVariable("@gl_NumWorkGroups", EbvNumWorkGroups, uvec3Type);
TIntermTyped* workGroupID = lookupBuiltinVariable("@gl_WorkGroupID", EbvWorkGroupId, uvec3Type);
//x & y components of gl_NumWorkGroups // uvec4 type.
TIntermTyped* numWorkGroupsX = handleDotDereference(loc, numWorkGroups, "x"); TType uvec4Type(EbtUint, EvqTemporary, 4);
TIntermTyped* numWorkGroupsY = handleDotDereference(loc, numWorkGroups, "y");
// x & y components of globalSize // Get the uvec4 return from subgroupBallot().
TIntermTyped* globalSizeX = handleBinaryMath(loc, "mul", EOpMul, numSubgroups, numWorkGroupsX); TIntermTyped* res = intermediate.addBuiltInFunctionCall(loc,
TIntermTyped* globalSizeY = numWorkGroupsY; EOpSubgroupBallot, true, arguments, uvec4Type);
// x, y & z components of gl_WorkGroupID // uint type.
TIntermTyped* workGroupX = handleDotDereference(loc, workGroupID, "x"); TType uintType(EbtUint, EvqTemporary);
TIntermTyped* workGroupY = handleDotDereference(loc, workGroupID, "y");
TIntermTyped* workGroupZ = handleDotDereference(loc, workGroupID, "z");
// We're going to build up the following variables to get a uniquely ordered ID: node = intermediate.addBuiltInFunctionCall(loc,
// (globalSize.y * gl_WorkGroupID.z + gl_WorkGroupID.y) * globalSize.x + gl_WorkGroupID.x + gl_SubgroupID EOpSubgroupBallotInclusiveBitCount, true, res, uintType);
node = handleBinaryMath(loc, "mul", EOpMul, globalSizeY, workGroupZ);
node = handleBinaryMath(loc, "add", EOpAdd, node, workGroupY);
node = handleBinaryMath(loc, "mul", EOpMul, node, globalSizeX);
node = handleBinaryMath(loc, "add", EOpAdd, node, workGroupX);
node = handleBinaryMath(loc, "add", EOpAdd, node, subgroupID);
break; break;
} }
case EOpGlobalOrderedCountIncrement:
{
// NOTE: For HLSL SM6.0 this should work, but the current GLSL extensions don't allow this.
error(loc, "GlobalOrderedCountIncrement() unsupported", "GlobalOrderedCountIncrement", "");
break;
}
default: default:
break; // most pass through unchanged break; // most pass through unchanged
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment