Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions naga/src/back/msl/mesh_shader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ impl<W: core::fmt::Write> super::Writer<W> {
writeln!(self.out, ") {{")?;

// Function body
if ep.stage == crate::ShaderStage::Mesh {
if ep.stage == crate::ShaderStage::Mesh || ep.stage == crate::ShaderStage::Task {
for (handle, var) in module.global_variables.iter() {
if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() {
continue;
Expand Down Expand Up @@ -252,7 +252,7 @@ impl<W: core::fmt::Write> super::Writer<W> {
is_first = false;
write!(self.out, "{}", arg.name)?;
}
if ep.stage == crate::ShaderStage::Mesh {
if ep.stage == crate::ShaderStage::Mesh || ep.stage == crate::ShaderStage::Task {
for (handle, var) in module.global_variables.iter() {
if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() {
continue;
Expand All @@ -261,7 +261,7 @@ impl<W: core::fmt::Write> super::Writer<W> {
write!(self.out, ", ")?;
}
let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, "{name}")?;
write!(self.out, "&{name}")?;
}
}
}
Expand Down
153 changes: 123 additions & 30 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ impl TypedGlobalVariable<'_> {
};
let (coherent, space, access, reference) = match (var.space.to_msl_name(), var.space) {
(Some(space), crate::AddressSpace::WorkGroup) => {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also match for TaskPayload address space?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it should, it doesn't map to the threadgroup address space.

Self::TaskPayload => Some("object_data"),

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Task payload storage class is defined to be basically identical to threadgroup storage class, except that in mesh shaders it is immutable. My point here is that this fix should therefore also apply to object_data variables so that they can benefit.

("", space, access, if self.reference { "&" } else { "" })
("", space, access, if self.reference { "*" } else { "" })
}
(Some(space), _) if self.reference => {
let coherent = if var
Expand Down Expand Up @@ -3143,6 +3143,20 @@ impl<W: Write> Writer<W> {
Ok(check_written)
}

fn is_root_workgroup_pointer(
&self,
chain: Handle<crate::Expression>,
context: &ExpressionContext,
) -> bool {
match context.function.expressions[chain] {
crate::Expression::GlobalVariable(handle) => {
let var = &context.module.global_variables[handle];
var.space == crate::AddressSpace::WorkGroup
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, should this match for task payload variables?

}
_ => false,
}
}

/// Write the access chain `chain`.
///
/// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
Expand Down Expand Up @@ -3201,13 +3215,22 @@ impl<W: Write> Writer<W> {
// indexing a struct with an expression.
match *base_ty {
crate::TypeInner::Struct { .. } => {
let is_workgroup = self.is_root_workgroup_pointer(base, context);
let op = if is_workgroup { "->" } else { "." };
let base_ty = base_ty_handle.unwrap();
self.put_access_chain(base, policy, context)?;
let name = &self.names[&NameKey::StructMember(base_ty, index)];
write!(self.out, ".{name}")?;
write!(self.out, "{op}{name}")?;
}
crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
let is_workgroup_ptr = self.is_root_workgroup_pointer(base, context);
if is_workgroup_ptr {
write!(self.out, "(*")?;
}
self.put_access_chain(base, policy, context)?;
if is_workgroup_ptr {
write!(self.out, ")")?;
}
// Prior to Metal v2.1 component access for packed vectors wasn't available
// however array indexing is
if context.get_packed_vec_kind(base).is_some() {
Expand Down Expand Up @@ -3267,9 +3290,18 @@ impl<W: Write> Writer<W> {
let accessing_wrapped_binding_array =
matches!(*base_ty, crate::TypeInner::BindingArray { .. });

let is_workgroup = self.is_root_workgroup_pointer(base, context);

if is_workgroup && !accessing_wrapped_array {
write!(self.out, "(*")?;
}
self.put_access_chain(base, policy, context)?;
if is_workgroup && !accessing_wrapped_array {
write!(self.out, ")")?;
}
if accessing_wrapped_array {
write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
let op = if is_workgroup { "->" } else { "." };
write!(self.out, "{op}{WRAPPED_ARRAY_FIELD}")?;
}
write!(self.out, "[")?;

Expand Down Expand Up @@ -3350,16 +3382,21 @@ impl<W: Write> Writer<W> {
.is_atomic_pointer(&context.module.types);

if is_atomic_pointer {
write!(
self.out,
"{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
)?;
write!(self.out, "{NAMESPACE}::atomic_load_explicit(")?;
let is_workgroup_ptr = self.is_root_workgroup_pointer(pointer, context);
if !is_workgroup_ptr {
write!(self.out, "{ATOMIC_REFERENCE}")?;
}
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
} else {
// We don't do any dereferencing with `*` here as pointer arguments to functions
// are done by `&` references and not `*` pointers. These do not need to be
// dereferenced.
// dereferenced, except for workgroups pointers.
let is_workgroup_ptr = self.is_root_workgroup_pointer(pointer, context);
if is_workgroup_ptr {
write!(self.out, "*")?;
}
self.put_access_chain(pointer, policy, context)?;
}

Expand Down Expand Up @@ -4006,9 +4043,13 @@ impl<W: Write> Writer<W> {
}

// Put the atomic function invocation.
let is_workgroup_atomic = self.is_root_workgroup_pointer(pointer, context);
match *fun {
crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}({ATOMIC_REFERENCE}")?;
write!(self.out, "{ATOMIC_COMP_EXCH_FUNCTION}(")?;
if !is_workgroup_atomic {
write!(self.out, "{ATOMIC_REFERENCE}")?;
}
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
self.put_expression(cmp, context, true)?;
Expand All @@ -4017,10 +4058,10 @@ impl<W: Write> Writer<W> {
write!(self.out, ")")?;
}
_ => {
write!(
self.out,
"{NAMESPACE}::atomic_{fun_key}_explicit({ATOMIC_REFERENCE}"
)?;
write!(self.out, "{NAMESPACE}::atomic_{fun_key}_explicit(")?;
if !is_workgroup_atomic {
write!(self.out, "{ATOMIC_REFERENCE}")?;
}
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
self.put_expression(value, context, true)?;
Expand Down Expand Up @@ -4274,16 +4315,21 @@ impl<W: Write> Writer<W> {
.is_atomic_pointer(&context.expression.module.types);

if is_atomic_pointer {
write!(
self.out,
"{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
)?;
write!(self.out, "{level}{NAMESPACE}::atomic_store_explicit(")?;
let is_workgroup_atomic = self.is_root_workgroup_pointer(pointer, &context.expression);
if !is_workgroup_atomic {
write!(self.out, "{ATOMIC_REFERENCE}")?;
}
self.put_access_chain(pointer, policy, &context.expression)?;
write!(self.out, ", ")?;
self.put_expression(value, &context.expression, true)?;
writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
} else {
write!(self.out, "{level}")?;
let is_workgroup_ptr = self.is_root_workgroup_pointer(pointer, &context.expression);
if is_workgroup_ptr {
write!(self.out, "*")?;
}
self.put_access_chain(pointer, policy, &context.expression)?;
write!(self.out, " = ")?;
self.put_expression(value, &context.expression, true)?;
Expand Down Expand Up @@ -7449,7 +7495,8 @@ template <typename A>
}
_ => {
if var.space == crate::AddressSpace::WorkGroup
&& ep.stage == crate::ShaderStage::Mesh
&& (ep.stage == crate::ShaderStage::Mesh
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the kind of stuff that I highlighted in the mesh shader ouptut but I don't know why it needed to change

|| ep.stage == crate::ShaderStage::Task)
{
continue;
}
Expand Down Expand Up @@ -7548,7 +7595,7 @@ template <typename A>
}
writeln!(self.out)?;
}
if ep.stage == crate::ShaderStage::Mesh {
if ep.stage == crate::ShaderStage::Mesh || ep.stage == crate::ShaderStage::Task {
for (handle, var) in module.global_variables.iter() {
if var.space != crate::AddressSpace::WorkGroup || fun_info[handle].is_empty() {
continue;
Expand All @@ -7567,7 +7614,7 @@ template <typename A>
};
writeln!(
self.out,
"threadgroup {ty_context}& {}",
"threadgroup {ty_context}* {}",
self.names[&NameKey::GlobalVariable(handle)]
)?;
}
Expand Down Expand Up @@ -8043,19 +8090,35 @@ mod workgroup_mem_init {
}

impl Access {
fn is_pointer_type(&self, module: &crate::Module) -> bool {
match *self {
Access::GlobalVariable(handle) => {
let var = &module.global_variables[handle];
var.space == crate::AddressSpace::WorkGroup
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again here, should this match a task payload pointer

}
Access::StructMember(..) | Access::Array(_) => false,
}
}

fn write<W: Write>(
&self,
writer: &mut W,
names: &FastHashMap<NameKey, String>,
op: &str,
) -> Result<(), core::fmt::Error> {
match *self {
Access::GlobalVariable(handle) => {
write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
}
Access::StructMember(handle, index) => {
write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
write!(
writer,
"{}{}",
op,
&names[&NameKey::StructMember(handle, index)]
)
}
Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
Access::Array(depth) => write!(writer, "{}{}[__i{depth}]", op, WRAPPED_ARRAY_FIELD),
}
}
}
Expand Down Expand Up @@ -8094,12 +8157,33 @@ mod workgroup_mem_init {
&self,
writer: &mut W,
names: &FastHashMap<NameKey, String>,
module: &crate::Module,
) -> Result<(), core::fmt::Error> {
for next in self.stack.iter() {
next.write(writer, names)?;
for (i, next) in self.stack.iter().enumerate() {
let op = if i == 0 {
// root item doesn't get an operator prefix
""
} else {
// check if the previous item is a pointer to determine the operator for this item
let prev = &self.stack[i - 1];
if prev.is_pointer_type(module) {
"->"
} else {
"."
}
};
next.write(writer, names, op)?;
}
Ok(())
}

fn root_is_workgroup_pointer(&self, module: &crate::Module) -> bool {
if let Some(&Access::GlobalVariable(handle)) = self.stack.first() {
let var = &module.global_variables[handle];
return var.space == crate::AddressSpace::WorkGroup;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You know the deal

}
false
}
}

impl<W: Write> Writer<W> {
Expand Down Expand Up @@ -8174,16 +8258,25 @@ mod workgroup_mem_init {
) -> BackendResult {
if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
write!(self.out, "{level}")?;
access_stack.write(&mut self.out, &self.names)?;
// workgroup variables are always pointers; add * to dereference at root level.
// Nested accesses use -> operator from the access stack.
let is_root_workgroup = access_stack.root_is_workgroup_pointer(module);
let is_nested = access_stack.stack.len() > 1;
if is_root_workgroup && !is_nested {
write!(self.out, "*")?;
}
access_stack.write(&mut self.out, &self.names, module)?;
writeln!(self.out, " = {{}};")?;
} else {
match module.types[ty].inner {
crate::TypeInner::Atomic { .. } => {
write!(
self.out,
"{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
)?;
access_stack.write(&mut self.out, &self.names)?;
write!(self.out, "{level}{NAMESPACE}::atomic_store_explicit(")?;
// only skip & for direct access to workgroup atomic
let is_nested = access_stack.stack.len() > 1;
if !access_stack.root_is_workgroup_pointer(module) || is_nested {
write!(self.out, "{ATOMIC_REFERENCE}")?;
}
access_stack.write(&mut self.out, &self.names, module)?;
writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
}
crate::TypeInner::Array { base, size, .. } => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ struct compute1_Input {
kernel void compute1_(
metal::uint3 local_invocation_id [[thread_position_in_threadgroup]]
, uint local_invocation_index [[thread_index_in_threadgroup]]
, threadgroup uint& wg_var
, threadgroup uint* wg_var
) {
if (local_invocation_index == 0u) {
wg_var = {};
*wg_var = {};
}
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
const Input input = { local_invocation_id, local_invocation_index };
wg_var = input.local_invocation_index * 2u;
uint _e6 = wg_var;
wg_var = _e6 + input.local_invocation_id[0];
*wg_var = input.local_invocation_index * 2u;
uint _e6 = *wg_var;
*wg_var = _e6 + input.local_invocation_id[0];
return;
}
8 changes: 4 additions & 4 deletions naga/tests/out/msl/wgsl-abstract-types-operators.metal
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,18 @@ void wgpu_4445_(
}

void wgpu_4435_(
threadgroup type_3& a
threadgroup type_3* a
) {
uint y = a.inner[as_type<int>(as_type<uint>(1) - as_type<uint>(1))];
uint y = a->inner[as_type<int>(as_type<uint>(1) - as_type<uint>(1))];
return;
}

kernel void main_(
uint __local_invocation_index [[thread_index_in_threadgroup]]
, threadgroup type_3& a
, threadgroup type_3* a
) {
if (__local_invocation_index == 0u) {
a = {};
*a = {};
}
metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
runtime_values();
Expand Down
Loading