-
Notifications
You must be signed in to change notification settings - Fork 1.3k
use threadgroup pointers instead of references in metal #9380
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: trunk
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -442,7 +442,7 @@ impl TypedGlobalVariable<'_> { | |
| }; | ||
| let (coherent, space, access, reference) = match (var.space.to_msl_name(), var.space) { | ||
| (Some(space), crate::AddressSpace::WorkGroup) => { | ||
| ("", space, access, if self.reference { "&" } else { "" }) | ||
| ("", space, access, if self.reference { "*" } else { "" }) | ||
| } | ||
| (Some(space), _) if self.reference => { | ||
| let coherent = if var | ||
|
|
@@ -3143,6 +3143,20 @@ impl<W: Write> Writer<W> { | |
| Ok(check_written) | ||
| } | ||
|
|
||
| fn is_root_workgroup_pointer( | ||
| &self, | ||
| chain: Handle<crate::Expression>, | ||
| context: &ExpressionContext, | ||
| ) -> bool { | ||
| match context.function.expressions[chain] { | ||
| crate::Expression::GlobalVariable(handle) => { | ||
| let var = &context.module.global_variables[handle]; | ||
| var.space == crate::AddressSpace::WorkGroup | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, should this match for task payload variables? |
||
| } | ||
| _ => false, | ||
| } | ||
| } | ||
|
|
||
| /// Write the access chain `chain`. | ||
| /// | ||
| /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions, | ||
|
|
@@ -3201,13 +3215,22 @@ impl<W: Write> Writer<W> { | |
| // indexing a struct with an expression. | ||
| match *base_ty { | ||
| crate::TypeInner::Struct { .. } => { | ||
| let is_workgroup = self.is_root_workgroup_pointer(base, context); | ||
| let op = if is_workgroup { "->" } else { "." }; | ||
| let base_ty = base_ty_handle.unwrap(); | ||
| self.put_access_chain(base, policy, context)?; | ||
| let name = &self.names[&NameKey::StructMember(base_ty, index)]; | ||
| write!(self.out, ".{name}")?; | ||
| write!(self.out, "{op}{name}")?; | ||
| } | ||
| crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => { | ||
| let is_workgroup_ptr = self.is_root_workgroup_pointer(base, context); | ||
| if is_workgroup_ptr { | ||
| write!(self.out, "(*")?; | ||
| } | ||
| self.put_access_chain(base, policy, context)?; | ||
| if is_workgroup_ptr { | ||
| write!(self.out, ")")?; | ||
| } | ||
| // Prior to Metal v2.1 component access for packed vectors wasn't available | ||
| // however array indexing is | ||
| if context.get_packed_vec_kind(base).is_some() { | ||
|
|
@@ -3267,9 +3290,18 @@ impl<W: Write> Writer<W> { | |
| let accessing_wrapped_binding_array = | ||
| matches!(*base_ty, crate::TypeInner::BindingArray { .. }); | ||
|
|
||
| let is_workgroup = self.is_root_workgroup_pointer(base, context); | ||
|
|
||
| if is_workgroup && !accessing_wrapped_array { | ||
| write!(self.out, "(*")?; | ||
| } | ||
| self.put_access_chain(base, policy, context)?; | ||
| if is_workgroup && !accessing_wrapped_array { | ||
| write!(self.out, ")")?; | ||
| } | ||
| if accessing_wrapped_array { | ||
| write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?; | ||
| let op = if is_workgroup { "->" } else { "." }; | ||
| write!(self.out, "{op}{WRAPPED_ARRAY_FIELD}")?; | ||
| } | ||
| write!(self.out, "[")?; | ||
|
|
||
|
|
@@ -3350,16 +3382,21 @@ impl<W: Write> Writer<W> { | |
| .is_atomic_pointer(&context.module.types); | ||
|
|
||
| if is_atomic_pointer { | ||
| write!( | ||
| self.out, | ||
| "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}" | ||
| )?; | ||
| write!(self.out, "{NAMESPACE}::atomic_load_explicit(")?; | ||
| let is_workgroup_ptr = self.is_root_workgroup_pointer(pointer, context); | ||
| if !is_workgroup_ptr { | ||
| write!(self.out, "{ATOMIC_REFERENCE}")?; | ||
| } | ||
| self.put_access_chain(pointer, policy, context)?; | ||
| write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?; | ||
| } else { | ||
| // We don't do any dereferencing with `*` here as pointer arguments to functions | ||
| // are done by `&` references and not `*` pointers. These do not need to be | ||
| // dereferenced. | ||
| // dereferenced, except for workgroups pointers. | ||
| let is_workgroup_ptr = self.is_root_workgroup_pointer(pointer, context); | ||
| if is_workgroup_ptr { | ||
| write!(self.out, "*")?; | ||
| } | ||
| self.put_access_chain(pointer, policy, context)?; | ||
| } | ||
|
|
||
|
|
@@ -4006,9 +4043,13 @@ impl<W: Write> Writer<W> { | |
| } | ||
|
|
||
| // Put the atomic function invocation. | ||
| let is_workgroup_atomic = self.is_root_workgroup_pointer(pointer, context); | ||
| match *fun { | ||
| crate::AtomicFunction::Exchange { compare: Some(cmp) } => { | ||
| write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?; | ||
| write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}(")?; | ||
| if !is_workgroup_atomic { | ||
| write!(self.out, "{ATOMIC_REFERENCE}")?; | ||
| } | ||
| self.put_access_chain(pointer, policy, context)?; | ||
| write!(self.out, ", ")?; | ||
| self.put_expression(cmp, context, true)?; | ||
|
|
@@ -4017,10 +4058,10 @@ impl<W: Write> Writer<W> { | |
| write!(self.out, ")")?; | ||
| } | ||
| _ => { | ||
| write!( | ||
| self.out, | ||
| "{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}" | ||
| )?; | ||
| write!(self.out, "{NAMESPACE}::atomic_{fun_key}_explicit(")?; | ||
| if !is_workgroup_atomic { | ||
| write!(self.out, "{ATOMIC_REFERENCE}")?; | ||
| } | ||
| self.put_access_chain(pointer, policy, context)?; | ||
| write!(self.out, ", ")?; | ||
| self.put_expression(value, context, true)?; | ||
|
|
@@ -4274,16 +4315,21 @@ impl<W: Write> Writer<W> { | |
| .is_atomic_pointer(&context.expression.module.types); | ||
|
|
||
| if is_atomic_pointer { | ||
| write!( | ||
| self.out, | ||
| "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}" | ||
| )?; | ||
| write!(self.out, "{level}{NAMESPACE}::atomic_store_explicit(")?; | ||
| let is_workgroup_atomic = self.is_root_workgroup_pointer(pointer, &context.expression); | ||
| if !is_workgroup_atomic { | ||
| write!(self.out, "{ATOMIC_REFERENCE}")?; | ||
| } | ||
| self.put_access_chain(pointer, policy, &context.expression)?; | ||
| write!(self.out, ", ")?; | ||
| self.put_expression(value, &context.expression, true)?; | ||
| writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?; | ||
| } else { | ||
| write!(self.out, "{level}")?; | ||
| let is_workgroup_ptr = self.is_root_workgroup_pointer(pointer, &context.expression); | ||
| if is_workgroup_ptr { | ||
| write!(self.out, "*")?; | ||
| } | ||
| self.put_access_chain(pointer, policy, &context.expression)?; | ||
| write!(self.out, " = ")?; | ||
| self.put_expression(value, &context.expression, true)?; | ||
|
|
@@ -7449,7 +7495,8 @@ template <typename A> | |
| } | ||
| _ => { | ||
| if var.space == crate::AddressSpace::WorkGroup | ||
| && ep.stage == crate::ShaderStage::Mesh | ||
| && (ep.stage == crate::ShaderStage::Mesh | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the kind of stuff that I highlighted in the mesh shader ouptut but I don't know why it needed to change |
||
| || ep.stage == crate::ShaderStage::Task) | ||
| { | ||
| continue; | ||
| } | ||
|
|
@@ -7548,7 +7595,7 @@ template <typename A> | |
| } | ||
| writeln!(self.out)?; | ||
| } | ||
| if ep.stage == crate::ShaderStage::Mesh { | ||
| if ep.stage == crate::ShaderStage::Mesh || ep.stage == crate::ShaderStage::Task { | ||
| for (handle, var) in module.global_variables.iter() { | ||
| if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() { | ||
| continue; | ||
|
|
@@ -7567,7 +7614,7 @@ template <typename A> | |
| }; | ||
| writeln!( | ||
| self.out, | ||
| "threadgroup {ty_context}& {}", | ||
| "threadgroup {ty_context}* {}", | ||
| self.names[&NameKey::GlobalVariable(handle)] | ||
| )?; | ||
| } | ||
|
|
@@ -8043,19 +8090,35 @@ mod workgroup_mem_init { | |
| } | ||
|
|
||
| impl Access { | ||
| fn is_pointer_type(&self, module: &crate::Module) -> bool { | ||
| match *self { | ||
| Access::GlobalVariable(handle) => { | ||
| let var = &module.global_variables[handle]; | ||
| var.space == crate::AddressSpace::WorkGroup | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again here, should this match a task payload pointer |
||
| } | ||
| Access::StructMember(..) | Access::Array(_) => false, | ||
| } | ||
| } | ||
|
|
||
| fn write<W: Write>( | ||
| &self, | ||
| writer: &mut W, | ||
| names: &FastHashMap<NameKey, String>, | ||
| op: &str, | ||
| ) -> Result<(), core::fmt::Error> { | ||
| match *self { | ||
| Access::GlobalVariable(handle) => { | ||
| write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)]) | ||
| } | ||
| Access::StructMember(handle, index) => { | ||
| write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)]) | ||
| write!( | ||
| writer, | ||
| "{}{}", | ||
| op, | ||
| &names[&NameKey::StructMember(handle, index)] | ||
| ) | ||
| } | ||
| Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"), | ||
| Access::Array(depth) => write!(writer, "{}{}[__i{depth}]", op, WRAPPED_ARRAY_FIELD), | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -8094,12 +8157,33 @@ mod workgroup_mem_init { | |
| &self, | ||
| writer: &mut W, | ||
| names: &FastHashMap<NameKey, String>, | ||
| module: &crate::Module, | ||
| ) -> Result<(), core::fmt::Error> { | ||
| for next in self.stack.iter() { | ||
| next.write(writer, names)?; | ||
| for (i, next) in self.stack.iter().enumerate() { | ||
| let op = if i == 0 { | ||
| // root item doesn't get an operator prefix | ||
| "" | ||
| } else { | ||
| // check if the previous item is a pointer to determine the operator for this item | ||
| let prev = &self.stack[i - 1]; | ||
| if prev.is_pointer_type(module) { | ||
| "->" | ||
| } else { | ||
| "." | ||
| } | ||
| }; | ||
| next.write(writer, names, op)?; | ||
| } | ||
| Ok(()) | ||
| } | ||
|
|
||
| fn root_is_workgroup_pointer(&self, module: &crate::Module) -> bool { | ||
| if let Some(&Access::GlobalVariable(handle)) = self.stack.first() { | ||
| let var = &module.global_variables[handle]; | ||
| return var.space == crate::AddressSpace::WorkGroup; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You know the deal |
||
| } | ||
| false | ||
| } | ||
| } | ||
|
|
||
| impl<W: Write> Writer<W> { | ||
|
|
@@ -8174,16 +8258,25 @@ mod workgroup_mem_init { | |
| ) -> BackendResult { | ||
| if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) { | ||
| write!(self.out, "{level}")?; | ||
| access_stack.write(&mut self.out, &self.names)?; | ||
| // workgroup variables are always pointers; add * to dereference at root level. | ||
| // Nested accesses use -> operator from the access stack. | ||
| let is_root_workgroup = access_stack.root_is_workgroup_pointer(module); | ||
| let is_nested = access_stack.stack.len() > 1; | ||
| if is_root_workgroup && !is_nested { | ||
| write!(self.out, "*")?; | ||
| } | ||
| access_stack.write(&mut self.out, &self.names, module)?; | ||
| writeln!(self.out, " = {{}};")?; | ||
| } else { | ||
| match module.types[ty].inner { | ||
| crate::TypeInner::Atomic { .. } => { | ||
| write!( | ||
| self.out, | ||
| "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}" | ||
| )?; | ||
| access_stack.write(&mut self.out, &self.names)?; | ||
| write!(self.out, "{level}{NAMESPACE}::atomic_store_explicit(")?; | ||
| // only skip & for direct access to workgroup atomic | ||
| let is_nested = access_stack.stack.len() > 1; | ||
| if !access_stack.root_is_workgroup_pointer(module) || is_nested { | ||
| write!(self.out, "{ATOMIC_REFERENCE}")?; | ||
| } | ||
| access_stack.write(&mut self.out, &self.names, module)?; | ||
| writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?; | ||
| } | ||
| crate::TypeInner::Array { base, size, .. } => { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this also match for TaskPayload address space?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it should, it doesn't map to the threadgroup address space.
wgpu/naga/src/back/msl/writer.rs
Line 688 in 253477b
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Task payload storage class is defined to be basically identical to threadgroup storage class, except that in mesh shaders it is immutable. My point here is that this fix should therefore also apply to object_data variables so that they can benefit.