@@ -5,50 +5,50 @@ ByteAddressBuffer bias_buffer;
5
5
RWByteAddressBuffer rw_matrix_buffer;
6
6
7
7
void UseCoopVec () {
8
- vector <float , 4 > output_vector;
9
- static const uint is_output_unsigned = 0 ;
10
-
11
- vector <float , 4 > input_vector;
12
- const uint is_input_unsigned = 0 ;
13
- const uint input_interpretation = 9 ; /*F32*/
14
-
15
- const uint matrix_offset = 0 ;
16
- const uint matrix_interpretation = 9 ; /*F32*/
17
- const uint matrix_dimM = 4 ;
18
- const uint matrix_dimK = 4 ;
19
- const uint matrix_layout = 0 ; /*RowMajor*/
20
- const bool matrix_is_transposed = false ;
21
- const uint matrix_stride = 64 ;
22
-
23
- __builtin_MatVecMul (output_vector, is_output_unsigned, input_vector,
24
- is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset,
25
- matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout,
26
- matrix_is_transposed, matrix_stride);
27
-
28
- const uint bias_offset = 0 ;
29
- const uint bias_interpretation = 9 ; /*F32*/
30
-
31
- __builtin_MatVecMulAdd (output_vector, is_output_unsigned, input_vector,
32
- is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset,
33
- matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout,
34
- matrix_is_transposed, matrix_stride, bias_buffer, bias_offset,
35
- bias_interpretation);
36
-
37
- vector <uint , 8 > input_vector1;
38
- vector <uint , 8 > input_vector2;
39
- const uint opa_matrix_offset = 0 ;
40
- const uint opa_matrix_interpretation = 5 ; /*U32*/
41
- const uint opa_matrix_layout = 3 ; /*OuterProductOptimal*/
42
- const uint opa_matrix_stride = 64 ;
43
-
44
- __builtin_OuterProductAccumulate (input_vector1, input_vector2,
45
- rw_matrix_buffer, opa_matrix_offset, opa_matrix_interpretation,
46
- opa_matrix_layout, opa_matrix_stride);
47
-
48
- const uint va_matrix_offset = 0 ;
49
-
50
- __builtin_VectorAccumulate (input_vector1, rw_matrix_buffer,
51
- va_matrix_offset);
8
+ vector <float , 4 > output_vector;
9
+ static const uint is_output_unsigned = 0 ;
10
+
11
+ vector <float , 4 > input_vector;
12
+ const uint is_input_unsigned = 0 ;
13
+ const uint input_interpretation = 9 ; /*F32*/
14
+
15
+ const uint matrix_offset = 0 ;
16
+ const uint matrix_interpretation = 9 ; /*F32*/
17
+ const uint matrix_dimM = 4 ;
18
+ const uint matrix_dimK = 4 ;
19
+ const uint matrix_layout = 0 ; /*RowMajor*/
20
+ const bool matrix_is_transposed = false ;
21
+ const uint matrix_stride = 64 ;
22
+
23
+ __builtin_MatVecMul (output_vector, is_output_unsigned, input_vector,
24
+ is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset,
25
+ matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout,
26
+ matrix_is_transposed, matrix_stride);
27
+
28
+ const uint bias_offset = 0 ;
29
+ const uint bias_interpretation = 9 ; /*F32*/
30
+
31
+ __builtin_MatVecMulAdd (output_vector, is_output_unsigned, input_vector,
32
+ is_input_unsigned, input_interpretation, matrix_buffer, matrix_offset,
33
+ matrix_interpretation, matrix_dimM, matrix_dimK, matrix_layout,
34
+ matrix_is_transposed, matrix_stride, bias_buffer, bias_offset,
35
+ bias_interpretation);
36
+
37
+ vector <uint , 8 > input_vector1;
38
+ vector <uint , 8 > input_vector2;
39
+ const uint opa_matrix_offset = 0 ;
40
+ const uint opa_matrix_interpretation = 5 ; /*U32*/
41
+ const uint opa_matrix_layout = 3 ; /*OuterProductOptimal*/
42
+ const uint opa_matrix_stride = 64 ;
43
+
44
+ __builtin_OuterProductAccumulate (input_vector1, input_vector2,
45
+ rw_matrix_buffer, opa_matrix_offset, opa_matrix_interpretation,
46
+ opa_matrix_layout, opa_matrix_stride);
47
+
48
+ const uint va_matrix_offset = 0 ;
49
+
50
+ __builtin_VectorAccumulate (input_vector1, rw_matrix_buffer,
51
+ va_matrix_offset);
52
52
}
53
53
54
54
// CHECK: define void @ps_main()
@@ -59,7 +59,7 @@ void UseCoopVec() {
59
59
60
60
[Shader ("pixel" )]
61
61
void ps_main ()
62
- {
62
+ {
63
63
UseCoopVec ();
64
64
}
65
65
@@ -72,8 +72,8 @@ void ps_main()
72
72
[Shader ("compute" )]
73
73
[NumThreads (1 ,1 ,1 )]
74
74
void cs_main ()
75
- {
76
- UseCoopVec ();
75
+ {
76
+ UseCoopVec ();
77
77
}
78
78
79
79
// CHECK: define void @vs_main()
@@ -85,11 +85,11 @@ void cs_main()
85
85
[Shader ("vertex" )]
86
86
void vs_main ()
87
87
{
88
- UseCoopVec ();
88
+ UseCoopVec ();
89
89
}
90
90
91
91
struct MyRecord{
92
- uint a;
92
+ uint a;
93
93
};
94
94
95
95
// CHECK: define void @ns_main()
@@ -101,8 +101,8 @@ struct MyRecord{
101
101
[Shader ("node" )]
102
102
[NodeLaunch ("thread" )]
103
103
void ns_main (ThreadNodeInputRecord<MyRecord> input)
104
- {
105
- UseCoopVec ();
104
+ {
105
+ UseCoopVec ();
106
106
}
107
107
108
108
// Vertex shader output structure
@@ -125,7 +125,7 @@ struct GS_OUT {
125
125
[shader ("geometry" )]
126
126
[maxvertexcount (3 )]
127
127
void gs_main (point VS_OUT input[1 ],
128
- inout TriangleStream <GS_OUT> OutputStream)
128
+ inout TriangleStream <GS_OUT> OutputStream)
129
129
{
130
130
UseCoopVec ();
131
131
}
0 commit comments