Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MSL: Implement geometry shaders using object and mesh stages. #2200

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 61 additions & 50 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,7 @@ struct CLIArguments
bool msl_raw_buffer_tese_input = false;
bool msl_multi_patch_workgroup = false;
bool msl_vertex_for_tessellation = false;
bool msl_for_mesh_pipeline = false;
uint32_t msl_additional_fixed_sample_mask = 0xffffffff;
bool msl_arrayed_subpass_input = false;
uint32_t msl_r32ui_linear_texture_alignment = 4;
Expand Down Expand Up @@ -900,23 +901,24 @@ static void print_help_msl()
"\t[--msl-enable-frag-output-mask <mask>]:\n\t\tOnly selectively enable fragment outputs. Useful if pipeline does not enable fragment output for certain locations, as pipeline creation might otherwise fail.\n"
"\t[--msl-no-clip-distance-user-varying]:\n\t\tDo not emit user varyings to emulate gl_ClipDistance in fragment shaders.\n"
"\t[--msl-add-shader-input <index> <format> <size> <rate>]:\n\t\tSpecify the format of the shader input at <index>.\n"
"\t\t<format> can be 'any32', 'any16', 'u16', 'u8', or 'other', to indicate a 32-bit opaque value, 16-bit opaque value, 16-bit unsigned integer, 8-bit unsigned integer, "
"or other-typed variable. <size> is the vector length of the variable, which must be greater than or equal to that declared in the shader. <rate> can be 'vertex', "
"\t\t<format> can be 'i32', 'i16', 'i8', 'u32', 'u16', 'u8', 'float', 'half', or 'other',\n\t\tto indicate a 32/16/8-bit integer (i) or unsigned integer (u), floating point, half-precision floating point, "
"or other-typed variable.\n\t\t'any16' or 'any32' can also be used to specify opaque 16-bit or 32-bit value.\n\t\t<size> is the vector length of the variable, which must be greater than or equal to that declared in the shader. <rate> can be 'vertex', "
"'primitive', or 'patch' to indicate a per-vertex, per-primitive, or per-patch variable.\n"
"\t\tUseful if shader stage interfaces don't match up, as pipeline creation might otherwise fail.\n"
"\t[--msl-add-shader-output <index> <format> <size> <rate>]:\n\t\tSpecify the format of the shader output at <index>.\n"
"\t\t<format> can be 'any32', 'any16', 'u16', 'u8', or 'other', to indicate a 32-bit opaque value, 16-bit opaque value, 16-bit unsigned integer, 8-bit unsigned integer, "
"or other-typed variable. <size> is the vector length of the variable, which must be greater than or equal to that declared in the shader. <rate> can be 'vertex', "
"\t\t<format> can be 'i32', 'i16', 'i8', 'u32', 'u16', 'u8', 'float', 'half', or 'other',\n\t\tto indicate a 32/16/8-bit integer (i) or unsigned integer (u), floating point, half-precision floating point, "
"or other-typed variable.\n\t\t'any16' or 'any32' can also be used to specify opaque 16-bit or 32-bit value.\n\t\t<size> is the vector length of the variable, which must be greater than or equal to that declared in the shader. <rate> can be 'vertex', "
"'primitive', or 'patch' to indicate a per-vertex, per-primitive, or per-patch variable.\n"
"\t\tUseful if shader stage interfaces don't match up, as pipeline creation might otherwise fail.\n"
"\t[--msl-shader-input <index> <format> <size>]:\n\t\tSpecify the format of the shader input at <index>.\n"
"\t\t<format> can be 'any32', 'any16', 'u16', 'u8', or 'other', to indicate a 32-bit opaque value, 16-bit opaque value, 16-bit unsigned integer, 8-bit unsigned integer, "
"or other-typed variable. <size> is the vector length of the variable, which must be greater than or equal to that declared in the shader."
"\t\tEquivalent to --msl-add-shader-input with a rate of 'vertex'.\n"
"\t[--msl-shader-attribute <index> <format> <size> <offset> <stride> <binding>]:\n\t\tSpecify the vertex attribute at <index>.\n"
"\t\t<format> can be 'i32', 'i16', 'i8', 'u32', 'u16', 'u8', 'float', 'half', or 'other',\n\t\tto indicate a 32/16/8-bit integer (i) or unsigned integer (u), floating point, half-precision floating point, "
"or other-typed variable.\n\t\t'any16' or 'any32' can also be used to specify opaque 16-bit or 32-bit value.\n\t\t<size> is the vector length of the variable, which must be greater than or equal to that declared in the shader."
"\n\t\tEquivalent to --msl-add-shader-input with a rate of 'vertex'.\n"
"\t[--msl-shader-output <index> <format> <size>]:\n\t\tSpecify the format of the shader output at <index>.\n"
"\t\t<format> can be 'any32', 'any16', 'u16', 'u8', or 'other', to indicate a 32-bit opaque value, 16-bit opaque value, 16-bit unsigned integer, 8-bit unsigned integer, "
"or other-typed variable. <size> is the vector length of the variable, which must be greater than or equal to that declared in the shader."
"\t\tEquivalent to --msl-add-shader-output with a rate of 'vertex'.\n"
"\t\t<format> can be 'i32', 'i16', 'i8', 'u32', 'u16', 'u8', 'float', 'half', or 'other',\n\t\tto indicate a 32/16/8-bit integer (i) or unsigned integer (u), floating point, half-precision floating point, "
"or other-typed variable.\n\t\t'any16' or 'any32' can also be used to specify opaque 16-bit or 32-bit value.\n\t\t<size> is the vector length of the variable, which must be greater than or equal to that declared in the shader."
"\n\t\tEquivalent to --msl-add-shader-output with a rate of 'vertex'.\n"
"\t[--msl-raw-buffer-tese-input]:\n\t\tUse raw buffers for tessellation evaluation input.\n"
"\t\tThis allows the use of nested structures and arrays.\n"
"\t\tIn a future version of SPIRV-Cross, this will become the default.\n"
Expand All @@ -926,6 +928,7 @@ static void print_help_msl()
"\t\tIn a future version of SPIRV-Cross, this will become the default.\n"
"\t[--msl-vertex-for-tessellation]:\n\t\tWhen handling a vertex shader, marks it as one that will be used with a new-style tessellation control shader.\n"
"\t\tThe vertex shader is output to MSL as a compute kernel which outputs vertices to the buffer in the order they are received, rather than in index order as with --msl-capture-output normally.\n"
"\t[--msl-for-mesh-pipeline]:\n\t\tWhen handling a vertex shader, marks it as one that will be used in a mesh pipeline in conjunction with a geometry shader.\n"
"\t[--msl-additional-fixed-sample-mask <mask>]:\n"
"\t\tSet an additional fixed sample mask. If the shader outputs a sample mask, then the final sample mask will be a bitwise AND of the two.\n"
"\t[--msl-arrayed-subpass-input]:\n\t\tAssume that images of dimension SubpassData have multiple layers. Layered input attachments are accessed relative to BuiltInLayer.\n"
Expand Down Expand Up @@ -1219,6 +1222,7 @@ static string compile_iteration(const CLIArguments &args, std::vector<uint32_t>
msl_opts.raw_buffer_tese_input = args.msl_raw_buffer_tese_input;
msl_opts.multi_patch_workgroup = args.msl_multi_patch_workgroup;
msl_opts.vertex_for_tessellation = args.msl_vertex_for_tessellation;
msl_opts.for_mesh_pipeline = args.msl_for_mesh_pipeline;
msl_opts.additional_fixed_sample_mask = args.msl_additional_fixed_sample_mask;
msl_opts.arrayed_subpass_input = args.msl_arrayed_subpass_input;
msl_opts.r32ui_linear_texture_alignment = args.msl_r32ui_linear_texture_alignment;
Expand Down Expand Up @@ -1554,6 +1558,34 @@ static string compile_iteration(const CLIArguments &args, std::vector<uint32_t>
return ret;
}

static MSLShaderVariableFormat parse_format(const char *text)
{
MSLShaderVariableFormat format;
if (strcmp(text, "i8") == 0)
format = MSL_SHADER_VARIABLE_FORMAT_INT8;
else if (strcmp(text, "i16") == 0)
format = MSL_SHADER_VARIABLE_FORMAT_INT16;
else if (strcmp(text, "i32") == 0)
format = MSL_SHADER_VARIABLE_FORMAT_INT32;
else if (strcmp(text, "u8") == 0)
format = MSL_SHADER_VARIABLE_FORMAT_UINT8;
else if (strcmp(text, "u16") == 0)
format = MSL_SHADER_VARIABLE_FORMAT_UINT16;
else if (strcmp(text, "u32") == 0)
format = MSL_SHADER_VARIABLE_FORMAT_UINT32;
else if (strcmp(text, "float") == 0)
format = MSL_SHADER_VARIABLE_FORMAT_FLOAT;
else if (strcmp(text, "half") == 0)
format = MSL_SHADER_VARIABLE_FORMAT_HALF;
else if (strcmp(text, "any16") == 0)
format = MSL_SHADER_VARIABLE_FORMAT_ANY16;
else if (strcmp(text, "any32") == 0)
format = MSL_SHADER_VARIABLE_FORMAT_ANY32;
else
format = MSL_SHADER_VARIABLE_FORMAT_OTHER;
return format;
}

static int main_inner(int argc, char *argv[])
{
CLIArguments args;
Expand Down Expand Up @@ -1685,16 +1717,7 @@ static int main_inner(int argc, char *argv[])
// Make sure next_uint() is called in-order.
input.location = parser.next_uint();
const char *format = parser.next_value_string("other");
if (strcmp(format, "any32") == 0)
input.format = MSL_SHADER_VARIABLE_FORMAT_ANY32;
else if (strcmp(format, "any16") == 0)
input.format = MSL_SHADER_VARIABLE_FORMAT_ANY16;
else if (strcmp(format, "u16") == 0)
input.format = MSL_SHADER_VARIABLE_FORMAT_UINT16;
else if (strcmp(format, "u8") == 0)
input.format = MSL_SHADER_VARIABLE_FORMAT_UINT8;
else
input.format = MSL_SHADER_VARIABLE_FORMAT_OTHER;
input.format = parse_format(format);
input.vecsize = parser.next_uint();
const char *rate = parser.next_value_string("vertex");
if (strcmp(rate, "primitive") == 0)
Expand All @@ -1710,16 +1733,7 @@ static int main_inner(int argc, char *argv[])
// Make sure next_uint() is called in-order.
output.location = parser.next_uint();
const char *format = parser.next_value_string("other");
if (strcmp(format, "any32") == 0)
output.format = MSL_SHADER_VARIABLE_FORMAT_ANY32;
else if (strcmp(format, "any16") == 0)
output.format = MSL_SHADER_VARIABLE_FORMAT_ANY16;
else if (strcmp(format, "u16") == 0)
output.format = MSL_SHADER_VARIABLE_FORMAT_UINT16;
else if (strcmp(format, "u8") == 0)
output.format = MSL_SHADER_VARIABLE_FORMAT_UINT8;
else
output.format = MSL_SHADER_VARIABLE_FORMAT_OTHER;
output.format = parse_format(format);
output.vecsize = parser.next_uint();
const char *rate = parser.next_value_string("vertex");
if (strcmp(rate, "primitive") == 0)
Expand All @@ -1730,21 +1744,26 @@ static int main_inner(int argc, char *argv[])
output.rate = MSL_SHADER_VARIABLE_RATE_PER_VERTEX;
args.msl_shader_outputs.push_back(output);
});
cbs.add("--msl-shader-attribute", [&args](CLIParser &parser) {
MSLShaderInterfaceVariable input;
// Make sure next_uint() is called in-order.
input.location = parser.next_uint();
const char *format = parser.next_value_string("other");
input.format = parse_format(format);
input.vecsize = parser.next_uint();

input.offset = parser.next_uint();
input.stride = parser.next_uint();
input.binding = parser.next_uint();

args.msl_shader_inputs.push_back(input);
});
cbs.add("--msl-shader-input", [&args](CLIParser &parser) {
MSLShaderInterfaceVariable input;
// Make sure next_uint() is called in-order.
input.location = parser.next_uint();
const char *format = parser.next_value_string("other");
if (strcmp(format, "any32") == 0)
input.format = MSL_SHADER_VARIABLE_FORMAT_ANY32;
else if (strcmp(format, "any16") == 0)
input.format = MSL_SHADER_VARIABLE_FORMAT_ANY16;
else if (strcmp(format, "u16") == 0)
input.format = MSL_SHADER_VARIABLE_FORMAT_UINT16;
else if (strcmp(format, "u8") == 0)
input.format = MSL_SHADER_VARIABLE_FORMAT_UINT8;
else
input.format = MSL_SHADER_VARIABLE_FORMAT_OTHER;
input.format = parse_format(format);
input.vecsize = parser.next_uint();
args.msl_shader_inputs.push_back(input);
});
Expand All @@ -1753,22 +1772,14 @@ static int main_inner(int argc, char *argv[])
// Make sure next_uint() is called in-order.
output.location = parser.next_uint();
const char *format = parser.next_value_string("other");
if (strcmp(format, "any32") == 0)
output.format = MSL_SHADER_VARIABLE_FORMAT_ANY32;
else if (strcmp(format, "any16") == 0)
output.format = MSL_SHADER_VARIABLE_FORMAT_ANY16;
else if (strcmp(format, "u16") == 0)
output.format = MSL_SHADER_VARIABLE_FORMAT_UINT16;
else if (strcmp(format, "u8") == 0)
output.format = MSL_SHADER_VARIABLE_FORMAT_UINT8;
else
output.format = MSL_SHADER_VARIABLE_FORMAT_OTHER;
output.format = parse_format(format);
output.vecsize = parser.next_uint();
args.msl_shader_outputs.push_back(output);
});
cbs.add("--msl-raw-buffer-tese-input", [&args](CLIParser &) { args.msl_raw_buffer_tese_input = true; });
cbs.add("--msl-multi-patch-workgroup", [&args](CLIParser &) { args.msl_multi_patch_workgroup = true; });
cbs.add("--msl-vertex-for-tessellation", [&args](CLIParser &) { args.msl_vertex_for_tessellation = true; });
cbs.add("--msl-for-mesh-pipeline", [&args](CLIParser &) { args.msl_for_mesh_pipeline = true; });
cbs.add("--msl-additional-fixed-sample-mask",
[&args](CLIParser &parser) { args.msl_additional_fixed_sample_mask = parser.next_hex_uint(); });
cbs.add("--msl-arrayed-subpass-input", [&args](CLIParser &) { args.msl_arrayed_subpass_input = true; });
Expand Down
159 changes: 159 additions & 0 deletions reference/shaders-msl/geom/basic.msl31.geom
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"

#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_mesh>

using namespace metal;

template<typename T, size_t Num>
struct spvUnsafeArray
{
T elements[Num ? Num : 1];

thread T& operator [] (size_t pos) thread
{
return elements[pos];
}
constexpr const thread T& operator [] (size_t pos) const thread
{
return elements[pos];
}

device T& operator [] (size_t pos) device
{
return elements[pos];
}
constexpr const device T& operator [] (size_t pos) const device
{
return elements[pos];
}

constexpr const constant T& operator [] (size_t pos) const constant
{
return elements[pos];
}

threadgroup T& operator [] (size_t pos) threadgroup
{
return elements[pos];
}
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
{
return elements[pos];
}
};

template<typename V, typename P, int MaxV, int MaxP, metal::topology T>
struct spvMeshStream
{
using mesh_t = metal::mesh<V, P, MaxV, MaxP, T>;
thread mesh_t &meshOut;
int currentVertex = 0;
int currentIndex = 0;
int currentVertexInPrimitive = 0;
int currentPrimitive = 0;
thread P &primitiveData;
thread V &vertexData;
spvMeshStream(thread mesh_t &_meshOut, thread V &_v, thread P &_p) : meshOut(_meshOut), primitiveData(_p), vertexData(_v)
{
}
~spvMeshStream()
{
meshOut.set_primitive_count(currentPrimitive);
}
int VperP()
{
if (T == metal::topology::triangle) return 3;
else if (T == metal::topology::line) return 2;
else /* if (T == metal::topology::point) */ return 1;
}
void EndPrimitive()
{
currentVertexInPrimitive = 0;
}
void EmitVertex()
{
meshOut.set_vertex(currentVertex++, vertexData);
currentVertexInPrimitive++;
if (currentVertexInPrimitive >= VperP())
{
if (T == metal::topology::triangle) meshOut.set_index(currentIndex++, currentVertex-3);
if (T == metal::topology::triangle || T == metal::topology::line) meshOut.set_index(currentIndex++, currentVertex-2);
meshOut.set_index(currentIndex++, currentVertex-1);
meshOut.set_primitive(currentPrimitive++, primitiveData);
}
}
};
struct VertexData
{
float3 normal;
float4 pos;
};

struct main0_out_2
{
};
struct main0_out_1
{
float3 vNormal [[user(locn0)]];
float4 gl_Position [[position]];
};

struct main0_out_2_1
{
};
struct main0_in
{
spvUnsafeArray<VertexData, 3> vin;
spvUnsafeArray<float4, 3> pos;
};

enum { VERTEX_COUNT = 3, PRIMITIVE_COUNT = 1 };
using mesh_stream_t = spvMeshStream<main0_out_1, main0_out_2_1, VERTEX_COUNT, PRIMITIVE_COUNT, metal::topology::triangle>;
void main0(mesh_stream_t::mesh_t spvMeshOut, main0_in in)
{
main0_out_1 out = {};
main0_out_2_1 out_1 = {};
mesh_stream_t meshStream(spvMeshOut, out, out_1);
out.gl_Position = in.pos[0];
out.vNormal = in.vin[0].normal;
meshStream.EmitVertex();
out.gl_Position = in.pos[1];
out.vNormal = in.vin[1].normal;
meshStream.EmitVertex();
out.gl_Position = in.pos[2];
out.vNormal = in.vin[2].normal;
meshStream.EmitVertex();
meshStream.EndPrimitive();
}

struct Payload
{
struct
{
struct
{
VertexData vin [[user(locn0)]];
float4 pos [[user(locn2)]];
} in;
} vertices[3];
};
[[mesh]] void main0(mesh_stream_t::mesh_t outputMesh, const object_data Payload &payload [[payload]],

uint lid [[thread_index_in_threadgroup]], uint tid [[threadgroup_position_in_grid]])
{
main0_in in;
const unsigned long vertexCount = 3;
for (unsigned long i = 0; i < vertexCount; ++i)
{
auto out = payload.vertices[i];
if (i < sizeof(in.pos) / sizeof(in.pos[0]))
in.pos[i] = out.in.pos;
if (i < sizeof(in.vin) / sizeof(in.vin[0]))
in.vin[i] = out.in.vin;
}
main0(outputMesh, in
);
}
Loading