diff --git a/naga/src/back/msl/mesh_shader.rs b/naga/src/back/msl/mesh_shader.rs index a172bde802d..e309c9ebbdc 100644 --- a/naga/src/back/msl/mesh_shader.rs +++ b/naga/src/back/msl/mesh_shader.rs @@ -217,7 +217,7 @@ impl super::Writer { writeln!(self.out, ") {{")?; // Function body - if ep.stage == crate::ShaderStage::Mesh { + if ep.stage == crate::ShaderStage::Mesh || ep.stage == crate::ShaderStage::Task { for (handle, var) in module.global_variables.iter() { if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() { continue; @@ -252,7 +252,7 @@ impl super::Writer { is_first = false; write!(self.out, "{}", arg.name)?; } - if ep.stage == crate::ShaderStage::Mesh { + if ep.stage == crate::ShaderStage::Mesh || ep.stage == crate::ShaderStage::Task { for (handle, var) in module.global_variables.iter() { if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() { continue; @@ -261,7 +261,7 @@ impl super::Writer { write!(self.out, ", ")?; } let name = &self.names[&NameKey::GlobalVariable(handle)]; - write!(self.out, "{name}")?; + write!(self.out, "&{name}")?; } } } diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 584bf9f86be..746664d60b7 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -442,7 +442,7 @@ impl TypedGlobalVariable<'_> { }; let (coherent, space, access, reference) = match (var.space.to_msl_name(), var.space) { (Some(space), crate::AddressSpace::WorkGroup) => { - ("", space, access, if self.reference { "&" } else { "" }) + ("", space, access, if self.reference { "*" } else { "" }) } (Some(space), _) if self.reference => { let coherent = if var @@ -3143,6 +3143,20 @@ impl Writer { Ok(check_written) } + fn is_root_workgroup_pointer( + &self, + chain: Handle, + context: &ExpressionContext, + ) -> bool { + match context.function.expressions[chain] { + crate::Expression::GlobalVariable(handle) => { + let var = &context.module.global_variables[handle]; + var.space == crate::AddressSpace::WorkGroup + } + _ => false, + } + } + /// Write the access chain `chain`. /// /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions, @@ -3201,13 +3215,22 @@ impl Writer { // indexing a struct with an expression. match *base_ty { crate::TypeInner::Struct { .. } => { + let is_workgroup = self.is_root_workgroup_pointer(base, context); + let op = if is_workgroup { "->" } else { "." }; let base_ty = base_ty_handle.unwrap(); self.put_access_chain(base, policy, context)?; let name = &self.names[&NameKey::StructMember(base_ty, index)]; - write!(self.out, ".{name}")?; + write!(self.out, "{op}{name}")?; } crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => { + let is_workgroup_ptr = self.is_root_workgroup_pointer(base, context); + if is_workgroup_ptr { + write!(self.out, "(*")?; + } self.put_access_chain(base, policy, context)?; + if is_workgroup_ptr { + write!(self.out, ")")?; + } // Prior to Metal v2.1 component access for packed vectors wasn't available // however array indexing is if context.get_packed_vec_kind(base).is_some() { @@ -3267,9 +3290,18 @@ impl Writer { let accessing_wrapped_binding_array = matches!(*base_ty, crate::TypeInner::BindingArray { .. }); + let is_workgroup = self.is_root_workgroup_pointer(base, context); + + if is_workgroup && !accessing_wrapped_array { + write!(self.out, "(*")?; + } self.put_access_chain(base, policy, context)?; + if is_workgroup && !accessing_wrapped_array { + write!(self.out, ")")?; + } if accessing_wrapped_array { - write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?; + let op = if is_workgroup { "->" } else { "." }; + write!(self.out, "{op}{WRAPPED_ARRAY_FIELD}")?; } write!(self.out, "[")?; @@ -3350,16 +3382,21 @@ impl Writer { .is_atomic_pointer(&context.module.types); if is_atomic_pointer { - write!( - self.out, - "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}" - )?; + write!(self.out, "{NAMESPACE}::atomic_load_explicit(")?; + let is_workgroup_ptr = self.is_root_workgroup_pointer(pointer, context); + if !is_workgroup_ptr { + write!(self.out, "{ATOMIC_REFERENCE}")?; + } self.put_access_chain(pointer, policy, context)?; write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?; } else { // We don't do any dereferencing with `*` here as pointer arguments to functions // are done by `&` references and not `*` pointers. These do not need to be - // dereferenced. + // dereferenced, except for workgroups pointers. + let is_workgroup_ptr = self.is_root_workgroup_pointer(pointer, context); + if is_workgroup_ptr { + write!(self.out, "*")?; + } self.put_access_chain(pointer, policy, context)?; } @@ -4006,9 +4043,13 @@ impl Writer { } // Put the atomic function invocation. + let is_workgroup_atomic = self.is_root_workgroup_pointer(pointer, context); match *fun { crate::AtomicFunction::Exchange { compare: Some(cmp) } => { - write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?; + write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}(")?; + if !is_workgroup_atomic { + write!(self.out, "{ATOMIC_REFERENCE}")?; + } self.put_access_chain(pointer, policy, context)?; write!(self.out, ", ")?; self.put_expression(cmp, context, true)?; @@ -4017,10 +4058,10 @@ impl Writer { write!(self.out, ")")?; } _ => { - write!( - self.out, - "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}" - )?; + write!(self.out, "{NAMESPACE}::atomic_{fun_key}_explicit(")?; + if !is_workgroup_atomic { + write!(self.out, "{ATOMIC_REFERENCE}")?; + } self.put_access_chain(pointer, policy, context)?; write!(self.out, ", ")?; self.put_expression(value, context, true)?; @@ -4274,16 +4315,21 @@ impl Writer { .is_atomic_pointer(&context.expression.module.types); if is_atomic_pointer { - write!( - self.out, - "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}" - )?; + write!(self.out, "{level}{NAMESPACE}::atomic_store_explicit(")?; + let is_workgroup_atomic = self.is_root_workgroup_pointer(pointer, &context.expression); + if !is_workgroup_atomic { + write!(self.out, "{ATOMIC_REFERENCE}")?; + } self.put_access_chain(pointer, policy, &context.expression)?; write!(self.out, ", ")?; self.put_expression(value, &context.expression, true)?; writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?; } else { write!(self.out, "{level}")?; + let is_workgroup_ptr = self.is_root_workgroup_pointer(pointer, &context.expression); + if is_workgroup_ptr { + write!(self.out, "*")?; + } self.put_access_chain(pointer, policy, &context.expression)?; write!(self.out, " = ")?; self.put_expression(value, &context.expression, true)?; @@ -7449,7 +7495,8 @@ template } _ => { if var.space == crate::AddressSpace::WorkGroup - && ep.stage == crate::ShaderStage::Mesh + && (ep.stage == crate::ShaderStage::Mesh + || ep.stage == crate::ShaderStage::Task) { continue; } @@ -7548,7 +7595,7 @@ template } writeln!(self.out)?; } - if ep.stage == crate::ShaderStage::Mesh { + if ep.stage == crate::ShaderStage::Mesh || ep.stage == crate::ShaderStage::Task { for (handle, var) in module.global_variables.iter() { if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() { continue; @@ -7567,7 +7614,7 @@ template }; writeln!( self.out, - "threadgroup {ty_context}& {}", + "threadgroup {ty_context}* {}", self.names[&NameKey::GlobalVariable(handle)] )?; } @@ -8043,19 +8090,35 @@ mod workgroup_mem_init { } impl Access { + fn is_pointer_type(&self, module: &crate::Module) -> bool { + match *self { + Access::GlobalVariable(handle) => { + let var = &module.global_variables[handle]; + var.space == crate::AddressSpace::WorkGroup + } + Access::StructMember(..) | Access::Array(_) => false, + } + } + fn write( &self, writer: &mut W, names: &FastHashMap, + op: &str, ) -> Result<(), core::fmt::Error> { match *self { Access::GlobalVariable(handle) => { write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)]) } Access::StructMember(handle, index) => { - write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)]) + write!( + writer, + "{}{}", + op, + &names[&NameKey::StructMember(handle, index)] + ) } - Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"), + Access::Array(depth) => write!(writer, "{}{}[__i{depth}]", op, WRAPPED_ARRAY_FIELD), } } } @@ -8094,12 +8157,33 @@ mod workgroup_mem_init { &self, writer: &mut W, names: &FastHashMap, + module: &crate::Module, ) -> Result<(), core::fmt::Error> { - for next in self.stack.iter() { - next.write(writer, names)?; + for (i, next) in self.stack.iter().enumerate() { + let op = if i == 0 { + // root item doesn't get an operator prefix + "" + } else { + // check if the previous item is a pointer to determine the operator for this item + let prev = &self.stack[i - 1]; + if prev.is_pointer_type(module) { + "->" + } else { + "." + } + }; + next.write(writer, names, op)?; } Ok(()) } + + fn root_is_workgroup_pointer(&self, module: &crate::Module) -> bool { + if let Some(&Access::GlobalVariable(handle)) = self.stack.first() { + let var = &module.global_variables[handle]; + return var.space == crate::AddressSpace::WorkGroup; + } + false + } } impl Writer { @@ -8174,16 +8258,25 @@ mod workgroup_mem_init { ) -> BackendResult { if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) { write!(self.out, "{level}")?; - access_stack.write(&mut self.out, &self.names)?; + // workgroup variables are always pointers; add * to dereference at root level. + // Nested accesses use -> operator from the access stack. + let is_root_workgroup = access_stack.root_is_workgroup_pointer(module); + let is_nested = access_stack.stack.len() > 1; + if is_root_workgroup && !is_nested { + write!(self.out, "*")?; + } + access_stack.write(&mut self.out, &self.names, module)?; writeln!(self.out, " = {{}};")?; } else { match module.types[ty].inner { crate::TypeInner::Atomic { .. } => { - write!( - self.out, - "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}" - )?; - access_stack.write(&mut self.out, &self.names)?; + write!(self.out, "{level}{NAMESPACE}::atomic_store_explicit(")?; + // only skip & for direct access to workgroup atomic + let is_nested = access_stack.stack.len() > 1; + if !access_stack.root_is_workgroup_pointer(module) || is_nested { + write!(self.out, "{ATOMIC_REFERENCE}")?; + } + access_stack.write(&mut self.out, &self.names, module)?; writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?; } crate::TypeInner::Array { base, size, .. } => { diff --git a/naga/tests/out/msl/wgsl-8820-multiple-local-invocation-index-id.metal b/naga/tests/out/msl/wgsl-8820-multiple-local-invocation-index-id.metal index 90411265c45..f9fdd8f5d49 100644 --- a/naga/tests/out/msl/wgsl-8820-multiple-local-invocation-index-id.metal +++ b/naga/tests/out/msl/wgsl-8820-multiple-local-invocation-index-id.metal @@ -14,15 +14,15 @@ struct compute1_Input { kernel void compute1_( metal::uint3 local_invocation_id [[thread_position_in_threadgroup]] , uint local_invocation_index [[thread_index_in_threadgroup]] -, threadgroup uint& wg_var +, threadgroup uint* wg_var ) { if (local_invocation_index == 0u) { - wg_var = {}; + *wg_var = {}; } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); const Input input = { local_invocation_id, local_invocation_index }; - wg_var = input.local_invocation_index * 2u; - uint _e6 = wg_var; - wg_var = _e6 + input.local_invocation_id[0]; + *wg_var = input.local_invocation_index * 2u; + uint _e6 = *wg_var; + *wg_var = _e6 + input.local_invocation_id[0]; return; } diff --git a/naga/tests/out/msl/wgsl-abstract-types-operators.metal b/naga/tests/out/msl/wgsl-abstract-types-operators.metal index bed023c0480..60e0ff32573 100644 --- a/naga/tests/out/msl/wgsl-abstract-types-operators.metal +++ b/naga/tests/out/msl/wgsl-abstract-types-operators.metal @@ -100,18 +100,18 @@ void wgpu_4445_( } void wgpu_4435_( - threadgroup type_3& a + threadgroup type_3* a ) { - uint y = a.inner[as_type(as_type(1) - as_type(1))]; + uint y = a->inner[as_type(as_type(1) - as_type(1))]; return; } kernel void main_( uint __local_invocation_index [[thread_index_in_threadgroup]] -, threadgroup type_3& a +, threadgroup type_3* a ) { if (__local_invocation_index == 0u) { - a = {}; + *a = {}; } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); runtime_values(); diff --git a/naga/tests/out/msl/wgsl-atomicOps.metal b/naga/tests/out/msl/wgsl-atomicOps.metal index 26f384cf11f..dd7b7af39bf 100644 --- a/naga/tests/out/msl/wgsl-atomicOps.metal +++ b/naga/tests/out/msl/wgsl-atomicOps.metal @@ -80,18 +80,18 @@ kernel void cs_main( , device metal::atomic_uint& storage_atomic_scalar [[user(fake0)]] , device type_4& storage_atomic_arr [[user(fake0)]] , device Struct& storage_struct [[user(fake0)]] -, threadgroup metal::atomic_uint& workgroup_atomic_scalar -, threadgroup type_4& workgroup_atomic_arr -, threadgroup Struct& workgroup_struct +, threadgroup metal::atomic_uint* workgroup_atomic_scalar +, threadgroup type_4* workgroup_atomic_arr +, threadgroup Struct* workgroup_struct ) { if (__local_invocation_index == 0u) { - metal::atomic_store_explicit(&workgroup_atomic_scalar, 0, metal::memory_order_relaxed); + metal::atomic_store_explicit(workgroup_atomic_scalar, 0, metal::memory_order_relaxed); for (int __i0 = 0; __i0 < 2; __i0++) { - metal::atomic_store_explicit(&workgroup_atomic_arr.inner[__i0], 0, metal::memory_order_relaxed); + metal::atomic_store_explicit(&workgroup_atomic_arr->inner[__i0], 0, metal::memory_order_relaxed); } - metal::atomic_store_explicit(&workgroup_struct.atomic_scalar, 0, metal::memory_order_relaxed); + metal::atomic_store_explicit(&workgroup_struct->atomic_scalar, 0, metal::memory_order_relaxed); for (int __i0 = 0; __i0 < 2; __i0++) { - metal::atomic_store_explicit(&workgroup_struct.atomic_arr.inner[__i0], 0, metal::memory_order_relaxed); + metal::atomic_store_explicit(&workgroup_struct->atomic_arr.inner[__i0], 0, metal::memory_order_relaxed); } } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); @@ -99,97 +99,97 @@ kernel void cs_main( metal::atomic_store_explicit(&storage_atomic_arr.inner[1], 1, metal::memory_order_relaxed); metal::atomic_store_explicit(&storage_struct.atomic_scalar, 1u, metal::memory_order_relaxed); metal::atomic_store_explicit(&storage_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); - metal::atomic_store_explicit(&workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); - metal::atomic_store_explicit(&workgroup_atomic_arr.inner[1], 1, metal::memory_order_relaxed); - metal::atomic_store_explicit(&workgroup_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - metal::atomic_store_explicit(&workgroup_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); + metal::atomic_store_explicit(workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); + metal::atomic_store_explicit(&workgroup_atomic_arr->inner[1], 1, metal::memory_order_relaxed); + metal::atomic_store_explicit(&workgroup_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + metal::atomic_store_explicit(&workgroup_struct->atomic_arr.inner[1], 1, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); uint l0_ = metal::atomic_load_explicit(&storage_atomic_scalar, metal::memory_order_relaxed); int l1_ = metal::atomic_load_explicit(&storage_atomic_arr.inner[1], metal::memory_order_relaxed); uint l2_ = metal::atomic_load_explicit(&storage_struct.atomic_scalar, metal::memory_order_relaxed); int l3_ = metal::atomic_load_explicit(&storage_struct.atomic_arr.inner[1], metal::memory_order_relaxed); - uint l4_ = metal::atomic_load_explicit(&workgroup_atomic_scalar, metal::memory_order_relaxed); - int l5_ = metal::atomic_load_explicit(&workgroup_atomic_arr.inner[1], metal::memory_order_relaxed); - uint l6_ = metal::atomic_load_explicit(&workgroup_struct.atomic_scalar, metal::memory_order_relaxed); - int l7_ = metal::atomic_load_explicit(&workgroup_struct.atomic_arr.inner[1], metal::memory_order_relaxed); + uint l4_ = metal::atomic_load_explicit(workgroup_atomic_scalar, metal::memory_order_relaxed); + int l5_ = metal::atomic_load_explicit(&workgroup_atomic_arr->inner[1], metal::memory_order_relaxed); + uint l6_ = metal::atomic_load_explicit(&workgroup_struct->atomic_scalar, metal::memory_order_relaxed); + int l7_ = metal::atomic_load_explicit(&workgroup_struct->atomic_arr.inner[1], metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); uint _e51 = metal::atomic_fetch_add_explicit(&storage_atomic_scalar, 1u, metal::memory_order_relaxed); int _e55 = metal::atomic_fetch_add_explicit(&storage_atomic_arr.inner[1], 1, metal::memory_order_relaxed); uint _e59 = metal::atomic_fetch_add_explicit(&storage_struct.atomic_scalar, 1u, metal::memory_order_relaxed); int _e64 = metal::atomic_fetch_add_explicit(&storage_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e67 = metal::atomic_fetch_add_explicit(&workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); - int _e71 = metal::atomic_fetch_add_explicit(&workgroup_atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e75 = metal::atomic_fetch_add_explicit(&workgroup_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - int _e80 = metal::atomic_fetch_add_explicit(&workgroup_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); + uint _e67 = metal::atomic_fetch_add_explicit(workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); + int _e71 = metal::atomic_fetch_add_explicit(&workgroup_atomic_arr->inner[1], 1, metal::memory_order_relaxed); + uint _e75 = metal::atomic_fetch_add_explicit(&workgroup_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + int _e80 = metal::atomic_fetch_add_explicit(&workgroup_struct->atomic_arr.inner[1], 1, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); uint _e83 = metal::atomic_fetch_sub_explicit(&storage_atomic_scalar, 1u, metal::memory_order_relaxed); int _e87 = metal::atomic_fetch_sub_explicit(&storage_atomic_arr.inner[1], 1, metal::memory_order_relaxed); uint _e91 = metal::atomic_fetch_sub_explicit(&storage_struct.atomic_scalar, 1u, metal::memory_order_relaxed); int _e96 = metal::atomic_fetch_sub_explicit(&storage_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e99 = metal::atomic_fetch_sub_explicit(&workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); - int _e103 = metal::atomic_fetch_sub_explicit(&workgroup_atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e107 = metal::atomic_fetch_sub_explicit(&workgroup_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - int _e112 = metal::atomic_fetch_sub_explicit(&workgroup_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); + uint _e99 = metal::atomic_fetch_sub_explicit(workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); + int _e103 = metal::atomic_fetch_sub_explicit(&workgroup_atomic_arr->inner[1], 1, metal::memory_order_relaxed); + uint _e107 = metal::atomic_fetch_sub_explicit(&workgroup_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + int _e112 = metal::atomic_fetch_sub_explicit(&workgroup_struct->atomic_arr.inner[1], 1, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); uint _e115 = metal::atomic_fetch_max_explicit(&storage_atomic_scalar, 1u, metal::memory_order_relaxed); int _e119 = metal::atomic_fetch_max_explicit(&storage_atomic_arr.inner[1], 1, metal::memory_order_relaxed); uint _e123 = metal::atomic_fetch_max_explicit(&storage_struct.atomic_scalar, 1u, metal::memory_order_relaxed); int _e128 = metal::atomic_fetch_max_explicit(&storage_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e131 = metal::atomic_fetch_max_explicit(&workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); - int _e135 = metal::atomic_fetch_max_explicit(&workgroup_atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e139 = metal::atomic_fetch_max_explicit(&workgroup_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - int _e144 = metal::atomic_fetch_max_explicit(&workgroup_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); + uint _e131 = metal::atomic_fetch_max_explicit(workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); + int _e135 = metal::atomic_fetch_max_explicit(&workgroup_atomic_arr->inner[1], 1, metal::memory_order_relaxed); + uint _e139 = metal::atomic_fetch_max_explicit(&workgroup_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + int _e144 = metal::atomic_fetch_max_explicit(&workgroup_struct->atomic_arr.inner[1], 1, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); uint _e147 = metal::atomic_fetch_min_explicit(&storage_atomic_scalar, 1u, metal::memory_order_relaxed); int _e151 = metal::atomic_fetch_min_explicit(&storage_atomic_arr.inner[1], 1, metal::memory_order_relaxed); uint _e155 = metal::atomic_fetch_min_explicit(&storage_struct.atomic_scalar, 1u, metal::memory_order_relaxed); int _e160 = metal::atomic_fetch_min_explicit(&storage_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e163 = metal::atomic_fetch_min_explicit(&workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); - int _e167 = metal::atomic_fetch_min_explicit(&workgroup_atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e171 = metal::atomic_fetch_min_explicit(&workgroup_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - int _e176 = metal::atomic_fetch_min_explicit(&workgroup_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); + uint _e163 = metal::atomic_fetch_min_explicit(workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); + int _e167 = metal::atomic_fetch_min_explicit(&workgroup_atomic_arr->inner[1], 1, metal::memory_order_relaxed); + uint _e171 = metal::atomic_fetch_min_explicit(&workgroup_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + int _e176 = metal::atomic_fetch_min_explicit(&workgroup_struct->atomic_arr.inner[1], 1, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); uint _e179 = metal::atomic_fetch_and_explicit(&storage_atomic_scalar, 1u, metal::memory_order_relaxed); int _e183 = metal::atomic_fetch_and_explicit(&storage_atomic_arr.inner[1], 1, metal::memory_order_relaxed); uint _e187 = metal::atomic_fetch_and_explicit(&storage_struct.atomic_scalar, 1u, metal::memory_order_relaxed); int _e192 = metal::atomic_fetch_and_explicit(&storage_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e195 = metal::atomic_fetch_and_explicit(&workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); - int _e199 = metal::atomic_fetch_and_explicit(&workgroup_atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e203 = metal::atomic_fetch_and_explicit(&workgroup_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - int _e208 = metal::atomic_fetch_and_explicit(&workgroup_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); + uint _e195 = metal::atomic_fetch_and_explicit(workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); + int _e199 = metal::atomic_fetch_and_explicit(&workgroup_atomic_arr->inner[1], 1, metal::memory_order_relaxed); + uint _e203 = metal::atomic_fetch_and_explicit(&workgroup_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + int _e208 = metal::atomic_fetch_and_explicit(&workgroup_struct->atomic_arr.inner[1], 1, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); uint _e211 = metal::atomic_fetch_or_explicit(&storage_atomic_scalar, 1u, metal::memory_order_relaxed); int _e215 = metal::atomic_fetch_or_explicit(&storage_atomic_arr.inner[1], 1, metal::memory_order_relaxed); uint _e219 = metal::atomic_fetch_or_explicit(&storage_struct.atomic_scalar, 1u, metal::memory_order_relaxed); int _e224 = metal::atomic_fetch_or_explicit(&storage_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e227 = metal::atomic_fetch_or_explicit(&workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); - int _e231 = metal::atomic_fetch_or_explicit(&workgroup_atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e235 = metal::atomic_fetch_or_explicit(&workgroup_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - int _e240 = metal::atomic_fetch_or_explicit(&workgroup_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); + uint _e227 = metal::atomic_fetch_or_explicit(workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); + int _e231 = metal::atomic_fetch_or_explicit(&workgroup_atomic_arr->inner[1], 1, metal::memory_order_relaxed); + uint _e235 = metal::atomic_fetch_or_explicit(&workgroup_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + int _e240 = metal::atomic_fetch_or_explicit(&workgroup_struct->atomic_arr.inner[1], 1, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); uint _e243 = metal::atomic_fetch_xor_explicit(&storage_atomic_scalar, 1u, metal::memory_order_relaxed); int _e247 = metal::atomic_fetch_xor_explicit(&storage_atomic_arr.inner[1], 1, metal::memory_order_relaxed); uint _e251 = metal::atomic_fetch_xor_explicit(&storage_struct.atomic_scalar, 1u, metal::memory_order_relaxed); int _e256 = metal::atomic_fetch_xor_explicit(&storage_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e259 = metal::atomic_fetch_xor_explicit(&workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); - int _e263 = metal::atomic_fetch_xor_explicit(&workgroup_atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e267 = metal::atomic_fetch_xor_explicit(&workgroup_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - int _e272 = metal::atomic_fetch_xor_explicit(&workgroup_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); + uint _e259 = metal::atomic_fetch_xor_explicit(workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); + int _e263 = metal::atomic_fetch_xor_explicit(&workgroup_atomic_arr->inner[1], 1, metal::memory_order_relaxed); + uint _e267 = metal::atomic_fetch_xor_explicit(&workgroup_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + int _e272 = metal::atomic_fetch_xor_explicit(&workgroup_struct->atomic_arr.inner[1], 1, metal::memory_order_relaxed); uint _e275 = metal::atomic_exchange_explicit(&storage_atomic_scalar, 1u, metal::memory_order_relaxed); int _e279 = metal::atomic_exchange_explicit(&storage_atomic_arr.inner[1], 1, metal::memory_order_relaxed); uint _e283 = metal::atomic_exchange_explicit(&storage_struct.atomic_scalar, 1u, metal::memory_order_relaxed); int _e288 = metal::atomic_exchange_explicit(&storage_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e291 = metal::atomic_exchange_explicit(&workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); - int _e295 = metal::atomic_exchange_explicit(&workgroup_atomic_arr.inner[1], 1, metal::memory_order_relaxed); - uint _e299 = metal::atomic_exchange_explicit(&workgroup_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - int _e304 = metal::atomic_exchange_explicit(&workgroup_struct.atomic_arr.inner[1], 1, metal::memory_order_relaxed); + uint _e291 = metal::atomic_exchange_explicit(workgroup_atomic_scalar, 1u, metal::memory_order_relaxed); + int _e295 = metal::atomic_exchange_explicit(&workgroup_atomic_arr->inner[1], 1, metal::memory_order_relaxed); + uint _e299 = metal::atomic_exchange_explicit(&workgroup_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + int _e304 = metal::atomic_exchange_explicit(&workgroup_struct->atomic_arr.inner[1], 1, metal::memory_order_relaxed); _atomic_compare_exchange_result_Uint_4_ _e308 = naga_atomic_compare_exchange_weak_explicit(&storage_atomic_scalar, 1u, 2u); _atomic_compare_exchange_result_Sint_4_ _e313 = naga_atomic_compare_exchange_weak_explicit(&storage_atomic_arr.inner[1], 1, 2); _atomic_compare_exchange_result_Uint_4_ _e318 = naga_atomic_compare_exchange_weak_explicit(&storage_struct.atomic_scalar, 1u, 2u); _atomic_compare_exchange_result_Sint_4_ _e324 = naga_atomic_compare_exchange_weak_explicit(&storage_struct.atomic_arr.inner[1], 1, 2); - _atomic_compare_exchange_result_Uint_4_ _e328 = naga_atomic_compare_exchange_weak_explicit(&workgroup_atomic_scalar, 1u, 2u); - _atomic_compare_exchange_result_Sint_4_ _e333 = naga_atomic_compare_exchange_weak_explicit(&workgroup_atomic_arr.inner[1], 1, 2); - _atomic_compare_exchange_result_Uint_4_ _e338 = naga_atomic_compare_exchange_weak_explicit(&workgroup_struct.atomic_scalar, 1u, 2u); - _atomic_compare_exchange_result_Sint_4_ _e344 = naga_atomic_compare_exchange_weak_explicit(&workgroup_struct.atomic_arr.inner[1], 1, 2); + _atomic_compare_exchange_result_Uint_4_ _e328 = naga_atomic_compare_exchange_weak_explicit(workgroup_atomic_scalar, 1u, 2u); + _atomic_compare_exchange_result_Sint_4_ _e333 = naga_atomic_compare_exchange_weak_explicit(&workgroup_atomic_arr->inner[1], 1, 2); + _atomic_compare_exchange_result_Uint_4_ _e338 = naga_atomic_compare_exchange_weak_explicit(&workgroup_struct->atomic_scalar, 1u, 2u); + _atomic_compare_exchange_result_Sint_4_ _e344 = naga_atomic_compare_exchange_weak_explicit(&workgroup_struct->atomic_arr.inner[1], 1, 2); return; } diff --git a/naga/tests/out/msl/wgsl-globals.metal b/naga/tests/out/msl/wgsl-globals.metal index 8bd1a5cd377..1091ea45ac1 100644 --- a/naga/tests/out/msl/wgsl-globals.metal +++ b/naga/tests/out/msl/wgsl-globals.metal @@ -61,8 +61,8 @@ void test_msl_packed_vec3_( kernel void main_( uint __local_invocation_index [[thread_index_in_threadgroup]] -, threadgroup type_2& wg -, threadgroup metal::atomic_uint& at_1 +, threadgroup type_2* wg +, threadgroup metal::atomic_uint* at_1 , device FooStruct& alignment [[user(fake0)]] , device type_6 const& dummy [[user(fake0)]] , constant type_8& float_vecs [[user(fake0)]] @@ -73,8 +73,8 @@ kernel void main_( , constant _mslBufferSizes& _buffer_sizes [[user(fake0)]] ) { if (__local_invocation_index == 0u) { - wg = {}; - metal::atomic_store_explicit(&at_1, 0, metal::memory_order_relaxed); + *wg = {}; + metal::atomic_store_explicit(at_1, 0, metal::memory_order_relaxed); } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); float Foo = 1.0; @@ -82,20 +82,20 @@ kernel void main_( test_msl_packed_vec3_(alignment); metal::float4x2 _e5 = global_nested_arrays_of_matrices_4x2_.inner[0].inner[0]; metal::float4 _e10 = global_nested_arrays_of_matrices_2x4_.inner[0].inner[0][0]; - wg.inner[7] = (_e5 * _e10).x; + wg->inner[7] = (_e5 * _e10).x; metal::float3x2 _e16 = global_mat; metal::float3 _e18 = global_vec; - wg.inner[6] = (_e16 * _e18).x; + wg->inner[6] = (_e16 * _e18).x; float _e26 = dummy[1].y; - wg.inner[5] = _e26; + wg->inner[5] = _e26; float _e32 = float_vecs.inner[0].w; - wg.inner[4] = _e32; + wg->inner[4] = _e32; float _e37 = alignment.v1_; - wg.inner[3] = _e37; + wg->inner[3] = _e37; float _e43 = alignment.v3_[0]; - wg.inner[2] = _e43; + wg->inner[2] = _e43; alignment.v1_ = 4.0; - wg.inner[1] = static_cast(1 + (_buffer_sizes.size3 - 0 - 8) / 8); - metal::atomic_store_explicit(&at_1, 2u, metal::memory_order_relaxed); + wg->inner[1] = static_cast(1 + (_buffer_sizes.size3 - 0 - 8) / 8); + metal::atomic_store_explicit(at_1, 2u, metal::memory_order_relaxed); return; } diff --git a/naga/tests/out/msl/wgsl-interface.metal b/naga/tests/out/msl/wgsl-interface.metal index c4f63ca9ae7..8f4a8b11d18 100644 --- a/naga/tests/out/msl/wgsl-interface.metal +++ b/naga/tests/out/msl/wgsl-interface.metal @@ -75,13 +75,13 @@ struct computeInput { , uint local_index [[thread_index_in_threadgroup]] , metal::uint3 wg_id [[threadgroup_position_in_grid]] , metal::uint3 num_wgs [[threadgroups_per_grid]] -, threadgroup type_4& output +, threadgroup type_4* output ) { if (local_index == 0u) { - output = {}; + *output = {}; } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); - output.inner[0] = (((global_id.x + local_id.x) + local_index) + wg_id.x) + num_wgs.x; + output->inner[0] = (((global_id.x + local_id.x) + local_index) + wg_id.x) + num_wgs.x; return; } diff --git a/naga/tests/out/msl/wgsl-mesh-shader-empty.metal b/naga/tests/out/msl/wgsl-mesh-shader-empty.metal index 61e3b1d9b45..268e5554377 100644 --- a/naga/tests/out/msl/wgsl-mesh-shader-empty.metal +++ b/naga/tests/out/msl/wgsl-mesh-shader-empty.metal @@ -71,10 +71,10 @@ struct ms_mainPrimitiveOutput { void _ms_main( uint __local_invocation_index , object_data TaskPayload const& taskPayload -, threadgroup MeshOutput& mesh_output +, threadgroup MeshOutput* mesh_output ) { if (__local_invocation_index == 0u) { - mesh_output = {}; + *mesh_output = {}; } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_object_data); @@ -86,7 +86,7 @@ void _ms_main( , object_data TaskPayload const& taskPayload [[payload]] ) { threadgroup MeshOutput mesh_output; - _ms_main(__local_invocation_index, taskPayload, mesh_output); + _ms_main(__local_invocation_index, taskPayload, &mesh_output); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_object_data); for(uint vertexIndex = __local_invocation_index; vertexIndex < metal::min(mesh_output.vertex_count, 3u); vertexIndex += 64) { diff --git a/naga/tests/out/msl/wgsl-mesh-shader-lines.metal b/naga/tests/out/msl/wgsl-mesh-shader-lines.metal index 504527a972e..475909ac8ef 100644 --- a/naga/tests/out/msl/wgsl-mesh-shader-lines.metal +++ b/naga/tests/out/msl/wgsl-mesh-shader-lines.metal @@ -70,10 +70,10 @@ struct ms_mainPrimitiveOutput { void _ms_main( uint __local_invocation_index , object_data TaskPayload const& taskPayload -, threadgroup MeshOutput& mesh_output +, threadgroup MeshOutput* mesh_output ) { if (__local_invocation_index == 0u) { - mesh_output = {}; + *mesh_output = {}; } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_object_data); @@ -85,7 +85,7 @@ void _ms_main( , object_data TaskPayload const& taskPayload [[payload]] ) { threadgroup MeshOutput mesh_output; - _ms_main(__local_invocation_index, taskPayload, mesh_output); + _ms_main(__local_invocation_index, taskPayload, &mesh_output); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_object_data); for(uint vertexIndex = __local_invocation_index; vertexIndex < metal::min(mesh_output.vertex_count, 2u); vertexIndex += 64) { diff --git a/naga/tests/out/msl/wgsl-mesh-shader-points.metal b/naga/tests/out/msl/wgsl-mesh-shader-points.metal index 761e83d56e0..4663cd5100d 100644 --- a/naga/tests/out/msl/wgsl-mesh-shader-points.metal +++ b/naga/tests/out/msl/wgsl-mesh-shader-points.metal @@ -71,10 +71,10 @@ struct ms_mainPrimitiveOutput { void _ms_main( uint __local_invocation_index , object_data TaskPayload const& taskPayload -, threadgroup MeshOutput& mesh_output +, threadgroup MeshOutput* mesh_output ) { if (__local_invocation_index == 0u) { - mesh_output = {}; + *mesh_output = {}; } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_object_data); @@ -86,7 +86,7 @@ void _ms_main( , object_data TaskPayload const& taskPayload [[payload]] ) { threadgroup MeshOutput mesh_output; - _ms_main(__local_invocation_index, taskPayload, mesh_output); + _ms_main(__local_invocation_index, taskPayload, &mesh_output); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_object_data); for(uint vertexIndex = __local_invocation_index; vertexIndex < metal::min(mesh_output.vertex_count, 1u); vertexIndex += 64) { diff --git a/naga/tests/out/msl/wgsl-mesh-shader.metal b/naga/tests/out/msl/wgsl-mesh-shader.metal index a359dd09110..7e57475bf84 100644 --- a/naga/tests/out/msl/wgsl-mesh-shader.metal +++ b/naga/tests/out/msl/wgsl-mesh-shader.metal @@ -54,15 +54,15 @@ void helper_writer( metal::uint3 _ts_main( uint __local_invocation_index , object_data TaskPayload& taskPayload -, threadgroup float& workgroupData +, threadgroup float* workgroupData ) { if (__local_invocation_index == 0u) { taskPayload = {}; - workgroupData = {}; + *workgroupData = {}; } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_object_data); - workgroupData = 1.0; + *workgroupData = 1.0; taskPayload.colorMask = metal::float4(1.0, 1.0, 0.0, 1.0); helper_writer(true, taskPayload); bool _e12 = helper_reader(taskPayload); @@ -74,9 +74,9 @@ metal::uint3 _ts_main( metal::mesh_grid_properties nagaMeshGrid , uint __local_invocation_index [[thread_index_in_threadgroup]] , object_data TaskPayload& taskPayload [[payload]] -, threadgroup float& workgroupData ) { - uint3 nagaGridSize = _ts_main(__local_invocation_index, taskPayload, workgroupData); + threadgroup float workgroupData; + uint3 nagaGridSize = _ts_main(__local_invocation_index, taskPayload, &workgroupData); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_object_data); if (__local_invocation_index == 0u) { @@ -151,31 +151,31 @@ struct ms_mainPrimitiveOutput { void _ms_main( uint __local_invocation_index , object_data TaskPayload const& taskPayload -, threadgroup float& workgroupData -, threadgroup MeshOutput& mesh_output +, threadgroup float* workgroupData +, threadgroup MeshOutput* mesh_output ) { if (__local_invocation_index == 0u) { - workgroupData = {}; - mesh_output = {}; + *workgroupData = {}; + *mesh_output = {}; } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_object_data); - mesh_output.vertex_count = 3u; - mesh_output.primitive_count = 1u; - workgroupData = 2.0; - mesh_output.vertices.inner[0].position = metal::float4(0.0, 1.0, 0.0, 1.0); + mesh_output->vertex_count = 3u; + mesh_output->primitive_count = 1u; + *workgroupData = 2.0; + mesh_output->vertices.inner[0].position = metal::float4(0.0, 1.0, 0.0, 1.0); metal::float4 _e23 = taskPayload.colorMask; - mesh_output.vertices.inner[0].color = metal::float4(0.0, 1.0, 0.0, 1.0) * _e23; - mesh_output.vertices.inner[1].position = metal::float4(-1.0, -1.0, 0.0, 1.0); + mesh_output->vertices.inner[0].color = metal::float4(0.0, 1.0, 0.0, 1.0) * _e23; + mesh_output->vertices.inner[1].position = metal::float4(-1.0, -1.0, 0.0, 1.0); metal::float4 _e45 = taskPayload.colorMask; - mesh_output.vertices.inner[1].color = metal::float4(0.0, 0.0, 1.0, 1.0) * _e45; - mesh_output.vertices.inner[2].position = metal::float4(1.0, -1.0, 0.0, 1.0); + mesh_output->vertices.inner[1].color = metal::float4(0.0, 0.0, 1.0, 1.0) * _e45; + mesh_output->vertices.inner[2].position = metal::float4(1.0, -1.0, 0.0, 1.0); metal::float4 _e67 = taskPayload.colorMask; - mesh_output.vertices.inner[2].color = metal::float4(1.0, 0.0, 0.0, 1.0) * _e67; - mesh_output.primitives.inner[0].indices = metal::uint3(0u, 1u, 2u); + mesh_output->vertices.inner[2].color = metal::float4(1.0, 0.0, 0.0, 1.0) * _e67; + mesh_output->primitives.inner[0].indices = metal::uint3(0u, 1u, 2u); bool _e86 = helper_reader(taskPayload); - mesh_output.primitives.inner[0].cull = !(_e86); - mesh_output.primitives.inner[0].colorMask = metal::float4(1.0, 0.0, 1.0, 1.0); + mesh_output->primitives.inner[0].cull = !(_e86); + mesh_output->primitives.inner[0].colorMask = metal::float4(1.0, 0.0, 1.0, 1.0); return; } @@ -186,7 +186,7 @@ void _ms_main( ) { threadgroup float workgroupData; threadgroup MeshOutput mesh_output; - _ms_main(__local_invocation_index, taskPayload, workgroupData, mesh_output); + _ms_main(__local_invocation_index, taskPayload, &workgroupData, &mesh_output); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_object_data); for(uint vertexIndex = __local_invocation_index; vertexIndex < metal::min(mesh_output.vertex_count, 3u); vertexIndex += 1) { @@ -219,27 +219,27 @@ struct ms_no_tsPrimitiveOutput { }; void _ms_no_ts( uint __local_invocation_index -, threadgroup float& workgroupData -, threadgroup MeshOutput& mesh_output +, threadgroup float* workgroupData +, threadgroup MeshOutput* mesh_output ) { if (__local_invocation_index == 0u) { - workgroupData = {}; - mesh_output = {}; + *workgroupData = {}; + *mesh_output = {}; } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_object_data); - mesh_output.vertex_count = 3u; - mesh_output.primitive_count = 1u; - workgroupData = 2.0; - mesh_output.vertices.inner[0].position = metal::float4(0.0, 1.0, 0.0, 1.0); - mesh_output.vertices.inner[0].color = metal::float4(0.0, 1.0, 0.0, 1.0); - mesh_output.vertices.inner[1].position = metal::float4(-1.0, -1.0, 0.0, 1.0); - mesh_output.vertices.inner[1].color = metal::float4(0.0, 0.0, 1.0, 1.0); - mesh_output.vertices.inner[2].position = metal::float4(1.0, -1.0, 0.0, 1.0); - mesh_output.vertices.inner[2].color = metal::float4(1.0, 0.0, 0.0, 1.0); - mesh_output.primitives.inner[0].indices = metal::uint3(0u, 1u, 2u); - mesh_output.primitives.inner[0].cull = false; - mesh_output.primitives.inner[0].colorMask = metal::float4(1.0, 0.0, 1.0, 1.0); + mesh_output->vertex_count = 3u; + mesh_output->primitive_count = 1u; + *workgroupData = 2.0; + mesh_output->vertices.inner[0].position = metal::float4(0.0, 1.0, 0.0, 1.0); + mesh_output->vertices.inner[0].color = metal::float4(0.0, 1.0, 0.0, 1.0); + mesh_output->vertices.inner[1].position = metal::float4(-1.0, -1.0, 0.0, 1.0); + mesh_output->vertices.inner[1].color = metal::float4(0.0, 0.0, 1.0, 1.0); + mesh_output->vertices.inner[2].position = metal::float4(1.0, -1.0, 0.0, 1.0); + mesh_output->vertices.inner[2].color = metal::float4(1.0, 0.0, 0.0, 1.0); + mesh_output->primitives.inner[0].indices = metal::uint3(0u, 1u, 2u); + mesh_output->primitives.inner[0].cull = false; + mesh_output->primitives.inner[0].colorMask = metal::float4(1.0, 0.0, 1.0, 1.0); return; } @@ -249,7 +249,7 @@ void _ms_no_ts( ) { threadgroup float workgroupData; threadgroup MeshOutput mesh_output; - _ms_no_ts(__local_invocation_index, workgroupData, mesh_output); + _ms_no_ts(__local_invocation_index, &workgroupData, &mesh_output); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_object_data); for(uint vertexIndex_1 = __local_invocation_index; vertexIndex_1 < metal::min(mesh_output.vertex_count, 3u); vertexIndex_1 += 1) { @@ -285,28 +285,28 @@ struct ms_divergentPrimitiveOutput { void _ms_divergent( metal::uint3 thread_id_1 , uint __local_invocation_index -, threadgroup float& workgroupData -, threadgroup MeshOutput& mesh_output +, threadgroup float* workgroupData +, threadgroup MeshOutput* mesh_output ) { if (__local_invocation_index == 0u) { - workgroupData = {}; - mesh_output = {}; + *workgroupData = {}; + *mesh_output = {}; } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_object_data); if (thread_id_1.x == 0u) { - mesh_output.vertex_count = 3u; - mesh_output.primitive_count = 1u; - workgroupData = 2.0; - mesh_output.vertices.inner[0].position = metal::float4(0.0, 1.0, 0.0, 1.0); - mesh_output.vertices.inner[0].color = metal::float4(0.0, 1.0, 0.0, 1.0); - mesh_output.vertices.inner[1].position = metal::float4(-1.0, -1.0, 0.0, 1.0); - mesh_output.vertices.inner[1].color = metal::float4(0.0, 0.0, 1.0, 1.0); - mesh_output.vertices.inner[2].position = metal::float4(1.0, -1.0, 0.0, 1.0); - mesh_output.vertices.inner[2].color = metal::float4(1.0, 0.0, 0.0, 1.0); - mesh_output.primitives.inner[0].indices = metal::uint3(0u, 1u, 2u); - mesh_output.primitives.inner[0].cull = false; - mesh_output.primitives.inner[0].colorMask = metal::float4(1.0, 0.0, 1.0, 1.0); + mesh_output->vertex_count = 3u; + mesh_output->primitive_count = 1u; + *workgroupData = 2.0; + mesh_output->vertices.inner[0].position = metal::float4(0.0, 1.0, 0.0, 1.0); + mesh_output->vertices.inner[0].color = metal::float4(0.0, 1.0, 0.0, 1.0); + mesh_output->vertices.inner[1].position = metal::float4(-1.0, -1.0, 0.0, 1.0); + mesh_output->vertices.inner[1].color = metal::float4(0.0, 0.0, 1.0, 1.0); + mesh_output->vertices.inner[2].position = metal::float4(1.0, -1.0, 0.0, 1.0); + mesh_output->vertices.inner[2].color = metal::float4(1.0, 0.0, 0.0, 1.0); + mesh_output->primitives.inner[0].indices = metal::uint3(0u, 1u, 2u); + mesh_output->primitives.inner[0].cull = false; + mesh_output->primitives.inner[0].colorMask = metal::float4(1.0, 0.0, 1.0, 1.0); return; } else { return; @@ -320,7 +320,7 @@ void _ms_divergent( ) { threadgroup float workgroupData; threadgroup MeshOutput mesh_output; - _ms_divergent(thread_id_1, __local_invocation_index, workgroupData, mesh_output); + _ms_divergent(thread_id_1, __local_invocation_index, &workgroupData, &mesh_output); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_object_data); for(uint vertexIndex_2 = __local_invocation_index; vertexIndex_2 < metal::min(mesh_output.vertex_count, 3u); vertexIndex_2 += 2) { diff --git a/naga/tests/out/msl/wgsl-overrides-atomicCompareExchangeWeak.metal b/naga/tests/out/msl/wgsl-overrides-atomicCompareExchangeWeak.metal index 8380883e370..90e3607d5ae 100644 --- a/naga/tests/out/msl/wgsl-overrides-atomicCompareExchangeWeak.metal +++ b/naga/tests/out/msl/wgsl-overrides-atomicCompareExchangeWeak.metal @@ -38,12 +38,12 @@ constant int o = 2; kernel void f( uint __local_invocation_index [[thread_index_in_threadgroup]] -, threadgroup metal::atomic_uint& a +, threadgroup metal::atomic_uint* a ) { if (__local_invocation_index == 0u) { - metal::atomic_store_explicit(&a, 0, metal::memory_order_relaxed); + metal::atomic_store_explicit(a, 0, metal::memory_order_relaxed); } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); - _atomic_compare_exchange_result_Uint_4_ _e5 = naga_atomic_compare_exchange_weak_explicit(&a, 2u, 1u); + _atomic_compare_exchange_result_Uint_4_ _e5 = naga_atomic_compare_exchange_weak_explicit(a, 2u, 1u); return; } diff --git a/naga/tests/out/msl/wgsl-policy-mix.metal b/naga/tests/out/msl/wgsl-policy-mix.metal index a31d80398f7..4e4dca43158 100644 --- a/naga/tests/out/msl/wgsl-policy-mix.metal +++ b/naga/tests/out/msl/wgsl-policy-mix.metal @@ -39,14 +39,14 @@ metal::float4 mock_function( device InStorage const& in_storage, constant InUniform& in_uniform, metal::texture2d_array image_2d_array, - threadgroup type_5& in_workgroup, + threadgroup type_5* in_workgroup, thread type_6& in_private ) { type_9 in_function = type_9 {{metal::float4(0.707, 0.0, 0.0, 1.0), metal::float4(0.0, 0.707, 0.0, 1.0)}}; metal::float4 _e18 = in_storage.a.inner[i]; metal::float4 _e22 = in_uniform.a.inner[i]; metal::float4 _e25 = (uint(l) < image_2d_array.get_num_mip_levels() && uint(i) < image_2d_array.get_array_size() && metal::all(metal::uint2(c) < metal::uint2(image_2d_array.get_width(l), image_2d_array.get_height(l))) ? image_2d_array.read(metal::uint2(c), i, l): DefaultConstructible()); - float _e29 = in_workgroup.inner[metal::min(unsigned(i), 29u)]; + float _e29 = in_workgroup->inner[metal::min(unsigned(i), 29u)]; float _e34 = in_private.inner[metal::min(unsigned(i), 39u)]; metal::float4 _e38 = in_function.inner[metal::min(unsigned(i), 1u)]; return ((((_e18 + _e22) + _e25) + metal::float4(_e29)) + metal::float4(_e34)) + _e38; @@ -57,11 +57,11 @@ kernel void main_( , device InStorage const& in_storage [[user(fake0)]] , constant InUniform& in_uniform [[user(fake0)]] , metal::texture2d_array image_2d_array [[user(fake0)]] -, threadgroup type_5& in_workgroup +, threadgroup type_5* in_workgroup ) { type_6 in_private = {}; if (__local_invocation_index == 0u) { - in_workgroup = {}; + *in_workgroup = {}; } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::float4 _e5 = mock_function(metal::int2(1, 2), 3, 4, in_storage, in_uniform, image_2d_array, in_workgroup, in_private); diff --git a/naga/tests/out/msl/wgsl-workgroup-uniform-load-atomic.metal b/naga/tests/out/msl/wgsl-workgroup-uniform-load-atomic.metal index ed4b600967d..68ed05f8b82 100644 --- a/naga/tests/out/msl/wgsl-workgroup-uniform-load-atomic.metal +++ b/naga/tests/out/msl/wgsl-workgroup-uniform-load-atomic.metal @@ -18,16 +18,16 @@ kernel void test_atomic_workgroup_uniform_load( metal::uint3 workgroup_id [[threadgroup_position_in_grid]] , metal::uint3 local_id [[thread_position_in_threadgroup]] , uint __local_invocation_index [[thread_index_in_threadgroup]] -, threadgroup metal::atomic_uint& wg_scalar -, threadgroup metal::atomic_int& wg_signed -, threadgroup AtomicStruct& wg_struct +, threadgroup metal::atomic_uint* wg_scalar +, threadgroup metal::atomic_int* wg_signed +, threadgroup AtomicStruct* wg_struct ) { if (__local_invocation_index == 0u) { - metal::atomic_store_explicit(&wg_scalar, 0, metal::memory_order_relaxed); - metal::atomic_store_explicit(&wg_signed, 0, metal::memory_order_relaxed); - metal::atomic_store_explicit(&wg_struct.atomic_scalar, 0, metal::memory_order_relaxed); + metal::atomic_store_explicit(wg_scalar, 0, metal::memory_order_relaxed); + metal::atomic_store_explicit(wg_signed, 0, metal::memory_order_relaxed); + metal::atomic_store_explicit(&wg_struct->atomic_scalar, 0, metal::memory_order_relaxed); for (int __i0 = 0; __i0 < 2; __i0++) { - metal::atomic_store_explicit(&wg_struct.atomic_arr.inner[__i0], 0, metal::memory_order_relaxed); + metal::atomic_store_explicit(&wg_struct->atomic_arr.inner[__i0], 0, metal::memory_order_relaxed); } } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); @@ -35,22 +35,22 @@ kernel void test_atomic_workgroup_uniform_load( bool local_1 = {}; bool local_2 = {}; uint active_tile_index = workgroup_id.x + (workgroup_id.y * 32768u); - uint _e11 = metal::atomic_fetch_or_explicit(&wg_scalar, static_cast(active_tile_index >= 64u), metal::memory_order_relaxed); - int _e14 = metal::atomic_fetch_add_explicit(&wg_signed, 1, metal::memory_order_relaxed); - metal::atomic_store_explicit(&wg_struct.atomic_scalar, 1u, metal::memory_order_relaxed); - int _e22 = metal::atomic_fetch_add_explicit(&wg_struct.atomic_arr.inner[0], 1, metal::memory_order_relaxed); + uint _e11 = metal::atomic_fetch_or_explicit(wg_scalar, static_cast(active_tile_index >= 64u), metal::memory_order_relaxed); + int _e14 = metal::atomic_fetch_add_explicit(wg_signed, 1, metal::memory_order_relaxed); + metal::atomic_store_explicit(&wg_struct->atomic_scalar, 1u, metal::memory_order_relaxed); + int _e22 = metal::atomic_fetch_add_explicit(&wg_struct->atomic_arr.inner[0], 1, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); - uint unnamed = metal::atomic_load_explicit(&wg_scalar, metal::memory_order_relaxed); + uint unnamed = metal::atomic_load_explicit(wg_scalar, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); - int unnamed_1 = metal::atomic_load_explicit(&wg_signed, metal::memory_order_relaxed); + int unnamed_1 = metal::atomic_load_explicit(wg_signed, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); - uint unnamed_2 = metal::atomic_load_explicit(&wg_struct.atomic_scalar, metal::memory_order_relaxed); + uint unnamed_2 = metal::atomic_load_explicit(&wg_struct->atomic_scalar, metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); - int unnamed_3 = metal::atomic_load_explicit(&wg_struct.atomic_arr.inner[0], metal::memory_order_relaxed); + int unnamed_3 = metal::atomic_load_explicit(&wg_struct->atomic_arr.inner[0], metal::memory_order_relaxed); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); if (unnamed == 0u) { local = unnamed_1 > 0; diff --git a/naga/tests/out/msl/wgsl-workgroup-uniform-load.metal b/naga/tests/out/msl/wgsl-workgroup-uniform-load.metal index 5b8b513c36d..63167e7fc05 100644 --- a/naga/tests/out/msl/wgsl-workgroup-uniform-load.metal +++ b/naga/tests/out/msl/wgsl-workgroup-uniform-load.metal @@ -14,14 +14,14 @@ struct test_workgroupUniformLoadInput { kernel void test_workgroupUniformLoad( metal::uint3 workgroup_id [[threadgroup_position_in_grid]] , uint __local_invocation_index [[thread_index_in_threadgroup]] -, threadgroup type_2& arr_i32_ +, threadgroup type_2* arr_i32_ ) { if (__local_invocation_index == 0u) { - arr_i32_ = {}; + *arr_i32_ = {}; } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); - int unnamed = arr_i32_.inner[workgroup_id.x]; + int unnamed = arr_i32_->inner[workgroup_id.x]; metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); if (unnamed > 10) { metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); diff --git a/naga/tests/out/msl/wgsl-workgroup-var-init.metal b/naga/tests/out/msl/wgsl-workgroup-var-init.metal index 6bb6bf96f6e..7e10ec9fd1c 100644 --- a/naga/tests/out/msl/wgsl-workgroup-var-init.metal +++ b/naga/tests/out/msl/wgsl-workgroup-var-init.metal @@ -21,20 +21,20 @@ struct WStruct { kernel void main_( uint __local_invocation_index [[thread_index_in_threadgroup]] -, threadgroup WStruct& w_mem +, threadgroup WStruct* w_mem , device type_1& output [[buffer(0)]] ) { if (__local_invocation_index == 0u) { - w_mem.arr = {}; - metal::atomic_store_explicit(&w_mem.atom, 0, metal::memory_order_relaxed); + w_mem->arr = {}; + metal::atomic_store_explicit(&w_mem->atom, 0, metal::memory_order_relaxed); for (int __i0 = 0; __i0 < 8; __i0++) { for (int __i1 = 0; __i1 < 8; __i1++) { - metal::atomic_store_explicit(&w_mem.atom_arr.inner[__i0].inner[__i1], 0, metal::memory_order_relaxed); + metal::atomic_store_explicit(&w_mem->atom_arr.inner[__i0].inner[__i1], 0, metal::memory_order_relaxed); } } } metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup); - type_1 _e3 = w_mem.arr; + type_1 _e3 = w_mem->arr; output = _e3; return; }