Skip to content

Commit 1d520a0

Browse files
authored
Fix workgroupUniformLoad returning atomic types (gfx-rs#8791)
1 parent bed71ef commit 1d520a0

12 files changed

+525
-42
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ Bottom level categories:
6868

6969
- Reject zero-value construction of a runtime-sized array with a validation error. Previously it would crash in the HLSL backend. By @mooori in [#8741](https://github.com/gfx-rs/wgpu/pull/8741).
7070
- Reject splat vector construction if the argument type does not match the type of the vector's scalar. Previously it would succeed. By @mooori in [#8829](https://github.com/gfx-rs/wgpu/pull/8829).
71+
- Fixed `workgroupUniformLoad` incorrectly returning an atomic when called on a pointer to an atomic; it now returns the inner `T`, as per the spec. By @cryvosh in [#8791](https://github.com/gfx-rs/wgpu/pull/8791).
7172

7273
### Documentation
7374

cts_runner/test.lst

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,9 @@ webgpu:shader,execution,expression,call,builtin,textureSampleBaseClampToEdge:2d_
220220
// NOTE: This is supposed to be an exhaustive listing underneath
221221
// `webgpu:shader,execution,expression,call,builtin,workgroupUniformLoad:*`, so exceptions can be
222222
// worked around.
223+
webgpu:shader,execution,expression,call,builtin,workgroupUniformLoad:types:type="atomic%3Cu32%3E";*
224+
webgpu:shader,execution,expression,call,builtin,workgroupUniformLoad:types:type="atomic%3Ci32%3E";*
225+
webgpu:shader,execution,expression,call,builtin,workgroupUniformLoad:types:type="AtomicInStruct";*
223226
webgpu:shader,execution,expression,call,builtin,workgroupUniformLoad:types:type="bool";*
224227
webgpu:shader,execution,expression,call,builtin,workgroupUniformLoad:types:type="u32";*
225228
webgpu:shader,execution,expression,call,builtin,workgroupUniformLoad:types:type="vec4u";*
@@ -228,10 +231,6 @@ webgpu:shader,execution,expression,call,builtin,workgroupUniformLoad:types:type=
228231
webgpu:shader,execution,expression,call,builtin,workgroupUniformLoad:types:type="SimpleStruct";*
229232
//FAIL: https://github.com/gfx-rs/wgpu/issues/8812
230233
// webgpu:shader,execution,expression,call,builtin,workgroupUniformLoad:types:type="ComplexStruct";*
231-
//FAIL: https://github.com/gfx-rs/wgpu/pull/8791
232-
// webgpu:shader,execution,expression,call,builtin,workgroupUniformLoad:types:type="atomic%3Cu32%3E";*
233-
// webgpu:shader,execution,expression,call,builtin,workgroupUniformLoad:types:type="atomic%3Ci32%3E";*
234-
// webgpu:shader,execution,expression,call,builtin,workgroupUniformLoad:types:type="AtomicInStruct";*
235234
webgpu:shader,execution,flow_control,return:*
236235
// Many other vertex_buffer_access subtests are also passing, but there are too many to enumerate.
237236
// Fails on Metal in CI only, not when running locally.

naga/src/back/spv/block.rs

Lines changed: 6 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4122,43 +4122,15 @@ impl BlockContext<'_> {
41224122
self.writer
41234123
.write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
41244124
let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
4125-
// Embed the body of
4126-
match self.write_access_chain(
4125+
// Match `Expression::Load` behavior, including `OpAtomicLoad` when
4126+
// loading from a pointer to `atomic<T>`.
4127+
let id = self.write_checked_load(
41274128
pointer,
41284129
&mut block,
41294130
AccessTypeAdjustment::None,
4130-
)? {
4131-
ExpressionPointer::Ready { pointer_id } => {
4132-
let id = self.gen_id();
4133-
block.body.push(Instruction::load(
4134-
result_type_id,
4135-
id,
4136-
pointer_id,
4137-
None,
4138-
));
4139-
self.cached[result] = id;
4140-
}
4141-
ExpressionPointer::Conditional { condition, access } => {
4142-
self.cached[result] = self.write_conditional_indexed_load(
4143-
result_type_id,
4144-
condition,
4145-
&mut block,
4146-
move |id_gen, block| {
4147-
// The in-bounds path. Perform the access and the load.
4148-
let pointer_id = access.result_id.unwrap();
4149-
let value_id = id_gen.next();
4150-
block.body.push(access);
4151-
block.body.push(Instruction::load(
4152-
result_type_id,
4153-
value_id,
4154-
pointer_id,
4155-
None,
4156-
));
4157-
value_id
4158-
},
4159-
)
4160-
}
4161-
}
4131+
result_type_id,
4132+
)?;
4133+
self.cached[result] = id;
41624134
self.writer
41634135
.write_control_barrier(crate::Barrier::WORK_GROUP, &mut block.body);
41644136
}

naga/src/front/wgsl/lower/mod.rs

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3041,9 +3041,33 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
30413041
ir::TypeInner::Pointer {
30423042
base,
30433043
space: ir::AddressSpace::WorkGroup,
3044-
} => base,
3045-
ref other => {
3046-
log::error!("Type {other:?} passed to workgroupUniformLoad");
3044+
} => match ctx.module.types[base].inner {
3045+
// Match `Expression::Load` semantics:
3046+
// loading through a pointer to `atomic<T>` produces a `T`.
3047+
ir::TypeInner::Atomic(scalar) => ctx.module.types.insert(
3048+
ir::Type {
3049+
name: None,
3050+
inner: ir::TypeInner::Scalar(scalar),
3051+
},
3052+
span,
3053+
),
3054+
_ => base,
3055+
},
3056+
ir::TypeInner::ValuePointer {
3057+
size,
3058+
scalar,
3059+
space: ir::AddressSpace::WorkGroup,
3060+
} => ctx.module.types.insert(
3061+
ir::Type {
3062+
name: None,
3063+
inner: match size {
3064+
Some(size) => ir::TypeInner::Vector { size, scalar },
3065+
None => ir::TypeInner::Scalar(scalar),
3066+
},
3067+
},
3068+
span,
3069+
),
3070+
_ => {
30473071
let span = ctx.ast_expressions.get_span(expr);
30483072
return Err(Box::new(Error::InvalidWorkGroupUniformLoad(span)));
30493073
}

naga/src/valid/function.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1479,7 +1479,23 @@ impl super::Validator {
14791479
base: ty,
14801480
space: AddressSpace::WorkGroup,
14811481
};
1482-
if !expected_pointer_inner.non_struct_equivalent(pointer_inner, context.types) {
1482+
// workgroupUniformLoad on atomic<T> returns T, not atomic<T>.
1483+
// Verify the pointer's atomic scalar matches the result scalar.
1484+
let atomic_specialization_ok = match *pointer_inner {
1485+
Ti::Pointer {
1486+
base: pointer_base,
1487+
space: AddressSpace::WorkGroup,
1488+
} => match (&context.types[pointer_base].inner, &context.types[ty].inner) {
1489+
(&Ti::Atomic(pointer_scalar), &Ti::Scalar(result_scalar)) => {
1490+
pointer_scalar == result_scalar
1491+
}
1492+
_ => false,
1493+
},
1494+
_ => false,
1495+
};
1496+
if !expected_pointer_inner.non_struct_equivalent(pointer_inner, context.types)
1497+
&& !atomic_specialization_ok
1498+
{
14831499
return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
14841500
.with_span_static(span, "WorkGroupUniformLoad"));
14851501
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Test workgroupUniformLoad specialization for atomic<T> -> T
2+
3+
struct AtomicStruct {
4+
atomic_scalar: atomic<u32>,
5+
atomic_arr: array<atomic<i32>, 2>,
6+
}
7+
8+
var<workgroup> wg_scalar: atomic<u32>;
9+
var<workgroup> wg_signed: atomic<i32>;
10+
var<workgroup> wg_struct: AtomicStruct;
11+
12+
@compute @workgroup_size(64)
13+
fn test_atomic_workgroup_uniform_load(
14+
@builtin(workgroup_id) workgroup_id: vec3u,
15+
@builtin(local_invocation_id) local_id: vec3u
16+
) {
17+
let active_tile_index = workgroup_id.x + workgroup_id.y * 32768;
18+
19+
// Each thread may set the atomics
20+
atomicOr(&wg_scalar, u32(active_tile_index >= 64));
21+
atomicAdd(&wg_signed, 1i);
22+
atomicStore(&wg_struct.atomic_scalar, 1u);
23+
atomicAdd(&wg_struct.atomic_arr[0], 1i);
24+
25+
workgroupBarrier();
26+
27+
// workgroupUniformLoad on atomic<u32> should return u32
28+
let scalar_val: u32 = workgroupUniformLoad(&wg_scalar);
29+
30+
// workgroupUniformLoad on atomic<i32> should return i32
31+
let signed_val: i32 = workgroupUniformLoad(&wg_signed);
32+
33+
// workgroupUniformLoad on struct.atomic_scalar should return u32
34+
let struct_scalar: u32 = workgroupUniformLoad(&wg_struct.atomic_scalar);
35+
36+
// workgroupUniformLoad on struct.atomic_arr[i] should return i32
37+
let struct_arr_val: i32 = workgroupUniformLoad(&wg_struct.atomic_arr[0]);
38+
39+
// Should be able to use all results in comparisons
40+
if scalar_val == 0u && signed_val > 0i && struct_scalar > 0u && struct_arr_val > 0i {
41+
return;
42+
}
43+
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#version 310 es
2+
3+
precision highp float;
4+
precision highp int;
5+
6+
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
7+
8+
struct AtomicStruct {
9+
uint atomic_scalar;
10+
int atomic_arr[2];
11+
};
12+
shared uint wg_scalar;
13+
14+
shared int wg_signed;
15+
16+
shared AtomicStruct wg_struct;
17+
18+
19+
void main() {
20+
if (gl_LocalInvocationID == uvec3(0u)) {
21+
wg_scalar = 0u;
22+
wg_signed = 0;
23+
wg_struct = AtomicStruct(0u, int[2](0, 0));
24+
}
25+
memoryBarrierShared();
26+
barrier();
27+
uvec3 workgroup_id = gl_WorkGroupID;
28+
uvec3 local_id = gl_LocalInvocationID;
29+
bool local = false;
30+
bool local_1 = false;
31+
bool local_2 = false;
32+
uint active_tile_index = (workgroup_id.x + (workgroup_id.y * 32768u));
33+
uint _e11 = atomicOr(wg_scalar, uint((active_tile_index >= 64u)));
34+
int _e14 = atomicAdd(wg_signed, 1);
35+
wg_struct.atomic_scalar = 1u;
36+
int _e22 = atomicAdd(wg_struct.atomic_arr[0], 1);
37+
memoryBarrierShared();
38+
barrier();
39+
memoryBarrierShared();
40+
barrier();
41+
uint _e24 = wg_scalar;
42+
memoryBarrierShared();
43+
barrier();
44+
memoryBarrierShared();
45+
barrier();
46+
int _e26 = wg_signed;
47+
memoryBarrierShared();
48+
barrier();
49+
memoryBarrierShared();
50+
barrier();
51+
uint _e29 = wg_struct.atomic_scalar;
52+
memoryBarrierShared();
53+
barrier();
54+
memoryBarrierShared();
55+
barrier();
56+
int _e33 = wg_struct.atomic_arr[0];
57+
memoryBarrierShared();
58+
barrier();
59+
if ((_e24 == 0u)) {
60+
local = (_e26 > 0);
61+
} else {
62+
local = false;
63+
}
64+
bool _e41 = local;
65+
if (_e41) {
66+
local_1 = (_e29 > 0u);
67+
} else {
68+
local_1 = false;
69+
}
70+
bool _e47 = local_1;
71+
if (_e47) {
72+
local_2 = (_e33 > 0);
73+
} else {
74+
local_2 = false;
75+
}
76+
bool _e53 = local_2;
77+
if (_e53) {
78+
return;
79+
} else {
80+
return;
81+
}
82+
}
83+
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
struct AtomicStruct {
2+
uint atomic_scalar;
3+
int atomic_arr[2];
4+
};
5+
6+
groupshared uint wg_scalar;
7+
groupshared int wg_signed;
8+
groupshared AtomicStruct wg_struct;
9+
10+
[numthreads(64, 1, 1)]
11+
void test_atomic_workgroup_uniform_load(uint3 workgroup_id : SV_GroupID, uint3 local_id : SV_GroupThreadID, uint3 __local_invocation_id : SV_GroupThreadID)
12+
{
13+
if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {
14+
wg_scalar = (uint)0;
15+
wg_signed = (int)0;
16+
wg_struct = (AtomicStruct)0;
17+
}
18+
GroupMemoryBarrierWithGroupSync();
19+
bool local = (bool)0;
20+
bool local_1 = (bool)0;
21+
bool local_2 = (bool)0;
22+
23+
uint active_tile_index = (workgroup_id.x + (workgroup_id.y * 32768u));
24+
uint _e11; InterlockedOr(wg_scalar, uint((active_tile_index >= 64u)), _e11);
25+
int _e14; InterlockedAdd(wg_signed, int(1), _e14);
26+
wg_struct.atomic_scalar = 1u;
27+
int _e22; InterlockedAdd(wg_struct.atomic_arr[0], int(1), _e22);
28+
GroupMemoryBarrierWithGroupSync();
29+
GroupMemoryBarrierWithGroupSync();
30+
uint _e24 = wg_scalar;
31+
GroupMemoryBarrierWithGroupSync();
32+
GroupMemoryBarrierWithGroupSync();
33+
int _e26 = wg_signed;
34+
GroupMemoryBarrierWithGroupSync();
35+
GroupMemoryBarrierWithGroupSync();
36+
uint _e29 = wg_struct.atomic_scalar;
37+
GroupMemoryBarrierWithGroupSync();
38+
GroupMemoryBarrierWithGroupSync();
39+
int _e33 = wg_struct.atomic_arr[0];
40+
GroupMemoryBarrierWithGroupSync();
41+
if ((_e24 == 0u)) {
42+
local = (_e26 > int(0));
43+
} else {
44+
local = false;
45+
}
46+
bool _e41 = local;
47+
if (_e41) {
48+
local_1 = (_e29 > 0u);
49+
} else {
50+
local_1 = false;
51+
}
52+
bool _e47 = local_1;
53+
if (_e47) {
54+
local_2 = (_e33 > int(0));
55+
} else {
56+
local_2 = false;
57+
}
58+
bool _e53 = local_2;
59+
if (_e53) {
60+
return;
61+
} else {
62+
return;
63+
}
64+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
(
2+
vertex:[
3+
],
4+
fragment:[
5+
],
6+
compute:[
7+
(
8+
entry_point:"test_atomic_workgroup_uniform_load",
9+
target_profile:"cs_5_1",
10+
),
11+
],
12+
)

0 commit comments

Comments
 (0)