Skip to content

Commit fbb0d4c

Browse files
committed
MSL: Update reference shaders.
1 parent 4258bbd commit fbb0d4c

7 files changed

Lines changed: 399 additions & 4 deletions
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
#pragma clang diagnostic ignored "-Wmissing-prototypes"
2+
#pragma clang diagnostic ignored "-Wmissing-braces"
3+
4+
#include <metal_stdlib>
5+
#include <simd/simd.h>
6+
#include <metal_mesh>
7+
8+
using namespace metal;
9+
10+
template<typename T, size_t Num>
11+
struct spvUnsafeArray
12+
{
13+
T elements[Num ? Num : 1];
14+
15+
thread T& operator [] (size_t pos) thread
16+
{
17+
return elements[pos];
18+
}
19+
constexpr const thread T& operator [] (size_t pos) const thread
20+
{
21+
return elements[pos];
22+
}
23+
24+
device T& operator [] (size_t pos) device
25+
{
26+
return elements[pos];
27+
}
28+
constexpr const device T& operator [] (size_t pos) const device
29+
{
30+
return elements[pos];
31+
}
32+
33+
constexpr const constant T& operator [] (size_t pos) const constant
34+
{
35+
return elements[pos];
36+
}
37+
38+
threadgroup T& operator [] (size_t pos) threadgroup
39+
{
40+
return elements[pos];
41+
}
42+
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
43+
{
44+
return elements[pos];
45+
}
46+
};
47+
48+
template<typename V, typename P, int MaxV, int MaxP, metal::topology T>
49+
struct spvMeshStream
50+
{
51+
using mesh_t = metal::mesh<V, P, MaxV, MaxP, T>;
52+
thread mesh_t &meshOut;
53+
int currentVertex = 0;
54+
int currentIndex = 0;
55+
int currentVertexInPrimitive = 0;
56+
int currentPrimitive = 0;
57+
thread P &primitiveData;
58+
thread V &vertexData;
59+
spvMeshStream(thread mesh_t &_meshOut, thread V &_v, thread P &_p) : meshOut(_meshOut), primitiveData(_p), vertexData(_v)
60+
{
61+
}
62+
~spvMeshStream()
63+
{
64+
meshOut.set_primitive_count(currentPrimitive);
65+
}
66+
int VperP()
67+
{
68+
if (T == metal::topology::triangle) return 3;
69+
else if (T == metal::topology::line) return 2;
70+
else /* if (T == metal::topology::point) */ return 1;
71+
}
72+
void EndPrimitive()
73+
{
74+
currentVertexInPrimitive = 0;
75+
}
76+
void EmitVertex()
77+
{
78+
meshOut.set_vertex(currentVertex++, vertexData);
79+
currentVertexInPrimitive++;
80+
if (currentVertexInPrimitive >= VperP())
81+
{
82+
if (T == metal::topology::triangle) meshOut.set_index(currentIndex++, currentVertex-3);
83+
if (T == metal::topology::triangle || T == metal::topology::line) meshOut.set_index(currentIndex++, currentVertex-2);
84+
meshOut.set_index(currentIndex++, currentVertex-1);
85+
meshOut.set_primitive(currentPrimitive++, primitiveData);
86+
}
87+
}
88+
};
89+
struct VertexData
90+
{
91+
float3 normal;
92+
float4 pos;
93+
};
94+
95+
struct main0_out_2
96+
{
97+
};
98+
struct main0_out_1
99+
{
100+
float3 vNormal [[user(locn0)]];
101+
float4 gl_Position [[position]];
102+
};
103+
104+
struct main0_out_2_1
105+
{
106+
};
107+
struct main0_in
108+
{
109+
spvUnsafeArray<VertexData, 3> vin;
110+
spvUnsafeArray<float4, 3> pos;
111+
};
112+
113+
enum { VERTEX_COUNT = 3, PRIMITIVE_COUNT = 1 };
114+
using mesh_stream_t = spvMeshStream<main0_out_1, main0_out_2_1, VERTEX_COUNT, PRIMITIVE_COUNT, metal::topology::triangle>;
115+
void main0(mesh_stream_t::mesh_t spvMeshOut, main0_in in)
116+
{
117+
main0_out_1 out = {};
118+
main0_out_2_1 out_1 = {};
119+
mesh_stream_t meshStream(spvMeshOut, out, out_1);
120+
out.gl_Position = in.pos[0];
121+
out.vNormal = in.vin[0].normal;
122+
meshStream.EmitVertex();
123+
out.gl_Position = in.pos[1];
124+
out.vNormal = in.vin[1].normal;
125+
meshStream.EmitVertex();
126+
out.gl_Position = in.pos[2];
127+
out.vNormal = in.vin[2].normal;
128+
meshStream.EmitVertex();
129+
meshStream.EndPrimitive();
130+
}
131+
132+
struct Payload
133+
{
134+
struct
135+
{
136+
struct
137+
{
138+
VertexData vin [[user(locn0)]];
139+
float4 pos [[user(locn2)]];
140+
} in;
141+
} vertices[3];
142+
};
143+
[[mesh]] void main0(mesh_stream_t::mesh_t outputMesh, const object_data Payload &payload [[payload]],
144+
145+
uint lid [[thread_index_in_threadgroup]], uint tid [[threadgroup_position_in_grid]])
146+
{
147+
main0_in in;
148+
const unsigned long vertexCount = 3;
149+
for (unsigned long i = 0; i < vertexCount; ++i)
150+
{
151+
auto out = payload.vertices[i];
152+
if (i < sizeof(in.pos) / sizeof(in.pos[0]))
153+
in.pos[i] = out.in.pos;
154+
if (i < sizeof(in.vin) / sizeof(in.vin[0]))
155+
in.vin[i] = out.in.vin;
156+
}
157+
main0(outputMesh, in
158+
);
159+
}
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
#pragma clang diagnostic ignored "-Wmissing-prototypes"
2+
#pragma clang diagnostic ignored "-Wmissing-braces"
3+
4+
#include <metal_stdlib>
5+
#include <simd/simd.h>
6+
#include <metal_mesh>
7+
8+
using namespace metal;
9+
10+
template<typename T, size_t Num>
11+
struct spvUnsafeArray
12+
{
13+
T elements[Num ? Num : 1];
14+
15+
thread T& operator [] (size_t pos) thread
16+
{
17+
return elements[pos];
18+
}
19+
constexpr const thread T& operator [] (size_t pos) const thread
20+
{
21+
return elements[pos];
22+
}
23+
24+
device T& operator [] (size_t pos) device
25+
{
26+
return elements[pos];
27+
}
28+
constexpr const device T& operator [] (size_t pos) const device
29+
{
30+
return elements[pos];
31+
}
32+
33+
constexpr const constant T& operator [] (size_t pos) const constant
34+
{
35+
return elements[pos];
36+
}
37+
38+
threadgroup T& operator [] (size_t pos) threadgroup
39+
{
40+
return elements[pos];
41+
}
42+
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
43+
{
44+
return elements[pos];
45+
}
46+
};
47+
48+
template<typename V, typename P, int MaxV, int MaxP, metal::topology T>
49+
struct spvMeshStream
50+
{
51+
using mesh_t = metal::mesh<V, P, MaxV, MaxP, T>;
52+
thread mesh_t &meshOut;
53+
int currentVertex = 0;
54+
int currentIndex = 0;
55+
int currentVertexInPrimitive = 0;
56+
int currentPrimitive = 0;
57+
thread P &primitiveData;
58+
thread V &vertexData;
59+
spvMeshStream(thread mesh_t &_meshOut, thread V &_v, thread P &_p) : meshOut(_meshOut), primitiveData(_p), vertexData(_v)
60+
{
61+
}
62+
~spvMeshStream()
63+
{
64+
meshOut.set_primitive_count(currentPrimitive);
65+
}
66+
int VperP()
67+
{
68+
if (T == metal::topology::triangle) return 3;
69+
else if (T == metal::topology::line) return 2;
70+
else /* if (T == metal::topology::point) */ return 1;
71+
}
72+
void EndPrimitive()
73+
{
74+
currentVertexInPrimitive = 0;
75+
}
76+
void EmitVertex()
77+
{
78+
meshOut.set_vertex(currentVertex++, vertexData);
79+
currentVertexInPrimitive++;
80+
if (currentVertexInPrimitive >= VperP())
81+
{
82+
if (T == metal::topology::triangle) meshOut.set_index(currentIndex++, currentVertex-3);
83+
if (T == metal::topology::triangle || T == metal::topology::line) meshOut.set_index(currentIndex++, currentVertex-2);
84+
meshOut.set_index(currentIndex++, currentVertex-1);
85+
meshOut.set_primitive(currentPrimitive++, primitiveData);
86+
}
87+
}
88+
};
89+
struct VertexData
90+
{
91+
float3 normal;
92+
float4 position;
93+
};
94+
95+
struct main0_out_2
96+
{
97+
};
98+
struct main0_out_1
99+
{
100+
float3 vNormal [[user(locn0)]];
101+
float4 gl_Position [[position]];
102+
};
103+
104+
struct main0_out_2_1
105+
{
106+
};
107+
struct main0_in
108+
{
109+
spvUnsafeArray<VertexData, 2> vin;
110+
};
111+
112+
enum { VERTEX_COUNT = 2, PRIMITIVE_COUNT = 1 };
113+
using mesh_stream_t = spvMeshStream<main0_out_1, main0_out_2_1, VERTEX_COUNT, PRIMITIVE_COUNT, metal::topology::line>;
114+
void main0(mesh_stream_t::mesh_t spvMeshOut, main0_in in)
115+
{
116+
main0_out_1 out = {};
117+
main0_out_2_1 out_1 = {};
118+
mesh_stream_t meshStream(spvMeshOut, out, out_1);
119+
out.gl_Position = in.vin[0].position;
120+
out.vNormal = in.vin[0].normal;
121+
meshStream.EmitVertex();
122+
out.gl_Position = in.vin[1].position;
123+
out.vNormal = in.vin[1].normal;
124+
meshStream.EmitVertex();
125+
meshStream.EndPrimitive();
126+
}
127+
128+
struct Payload
129+
{
130+
struct
131+
{
132+
struct
133+
{
134+
VertexData vin [[user(locn0)]];
135+
} in;
136+
} vertices[2];
137+
};
138+
[[mesh]] void main0(mesh_stream_t::mesh_t outputMesh, const object_data Payload &payload [[payload]],
139+
140+
uint lid [[thread_index_in_threadgroup]], uint tid [[threadgroup_position_in_grid]])
141+
{
142+
main0_in in;
143+
const unsigned long vertexCount = 2;
144+
for (unsigned long i = 0; i < vertexCount; ++i)
145+
{
146+
auto out = payload.vertices[i];
147+
if (i < sizeof(in.vin) / sizeof(in.vin[0]))
148+
in.vin[i] = out.in.vin;
149+
}
150+
main0(outputMesh, in
151+
);
152+
}

reference/shaders-msl/tesc/arrayed-block-io.multi-patch.tesc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ struct main0_patchOut
6868
struct main0_in
6969
{
7070
float3 in_tc_attr;
71-
ushort2 m_179;
71+
ushort2 m_182;
7272
};
7373

7474
kernel void main0(uint3 gl_GlobalInvocationID [[thread_position_in_grid]], constant uint* spvIndirectParams [[buffer(29)]], device main0_patchOut* spvPatchOut [[buffer(27)]], device MTLQuadTessellationFactorsHalf* spvTessLevel [[buffer(26)]], device main0_in* spvIn [[buffer(22)]])

reference/shaders-msl/tesc/matrix-output.multi-patch.tesc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ struct main0_out
1313
struct main0_in
1414
{
1515
float3 in_tc_attr;
16-
ushort2 m_103;
16+
ushort2 m_106;
1717
};
1818

1919
kernel void main0(uint3 gl_GlobalInvocationID [[thread_position_in_grid]], device main0_out* spvOut [[buffer(28)]], constant uint* spvIndirectParams [[buffer(29)]], device MTLQuadTessellationFactorsHalf* spvTessLevel [[buffer(26)]], device main0_in* spvIn [[buffer(22)]])

reference/shaders-msl/tesc/struct-output.multi-patch.tesc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ struct main0_out
2020
struct main0_in
2121
{
2222
float3 in_tc_attr;
23-
ushort2 m_107;
23+
ushort2 m_110;
2424
};
2525

2626
kernel void main0(uint3 gl_GlobalInvocationID [[thread_position_in_grid]], device main0_out* spvOut [[buffer(28)]], constant uint* spvIndirectParams [[buffer(29)]], device MTLQuadTessellationFactorsHalf* spvTessLevel [[buffer(26)]], device main0_in* spvIn [[buffer(22)]])

reference/shaders-msl/tesc/water_tess.multi-patch.tesc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ struct main0_patchOut
2424
struct main0_in
2525
{
2626
float3 vPatchPosBase;
27-
ushort2 m_430;
27+
ushort2 m_433;
2828
};
2929

3030
static inline __attribute__((always_inline))

0 commit comments

Comments
 (0)