@@ -11,17 +11,17 @@ DEVICE_FORCEINLINE void size_8_fwd_dit(e2f *x) {
   // first stage
 #pragma unroll
   for (unsigned i{0}; i < 4; i++) {
-    const e2f tmp = x[i];
-    x[i] = e2f::add(tmp, x[i + 4]);
-    x[i + 4] = e2f::sub(tmp, x[i + 4]);
+    const e2f tmp = x[i];
+    x[i] = e2f::add(tmp, x[i + 4]);
+    x[i + 4] = e2f::sub(tmp, x[i + 4]);
   }

   // second stage
 #pragma unroll
   for (unsigned i{0}; i < 2; i++) {
-    const e2f tmp = x[i];
-    x[i] = e2f::add(tmp, x[i + 2]);
-    x[i + 2] = e2f::sub(tmp, x[i + 2]);
+    const e2f tmp = x[i];
+    x[i] = e2f::add(tmp, x[i + 2]);
+    x[i + 2] = e2f::sub(tmp, x[i + 2]);
   }
   // x[4] = x[4] + W_1_4 * (x[6].real + i * x[6].imag)
   //      = x[4] + (-i) * (x[6].real + i * x[6].imag)
@@ -31,19 +31,19 @@ DEVICE_FORCEINLINE void size_8_fwd_dit(e2f *x) {
   //      = x[4] + (-x[6].imag + i * x[6].real)
 #pragma unroll
   for (unsigned i{4}; i < 6; i++) {
-    const e2f tmp0 = x[i];
-    x[i][0] = bf::add(x[i][0], x[i + 2][1]);
-    x[i][1] = bf::sub(x[i][1], x[i + 2][0]);
-    const bf tmp1 = x[i + 2][0];
-    x[i + 2][0] = bf::sub(tmp0[0], x[i + 2][1]);
-    x[i + 2][1] = bf::add(tmp0[1], tmp1);
+    const e2f tmp0 = x[i];
+    x[i][0] = bf::add(x[i][0], x[i + 2][1]);
+    x[i][1] = bf::sub(x[i][1], x[i + 2][0]);
+    const bf tmp1 = x[i + 2][0];
+    x[i + 2][0] = bf::sub(tmp0[0], x[i + 2][1]);
+    x[i + 2][1] = bf::add(tmp0[1], tmp1);
   }

   // third stage
   {
     // x[3] = W_1_4 * x[3]
     //      = -i * (x[3].real + i * x[3].imag)
-    //      = x[3].imag - i * x[3].real)
+    //      = x[3].imag - i * x[3].real
     const bf tmp = x[3][0];
     x[3][0] = x[3][1];
     x[3][1] = bf::neg(tmp);
@@ -52,9 +52,9 @@ DEVICE_FORCEINLINE void size_8_fwd_dit(e2f *x) {
   x[7] = e2f::mul(W_3_8, x[7]); // don't bother optimizing, marginal gains
 #pragma unroll
   for (unsigned i{0}; i < 8; i += 2) {
-    const e2f tmp = x[i];
-    x[i] = e2f::add(tmp, x[i + 1]);
-    x[i + 1] = e2f::sub(tmp, x[i + 1]);
+    const e2f tmp = x[i];
+    x[i] = e2f::add(tmp, x[i + 1]);
+    x[i + 1] = e2f::sub(tmp, x[i + 1]);
   }

   // undo bitrev
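
// For reference, the stage-2/3 comments above rely on W_1_4 = -i, so the twiddle multiply
// collapses to a swap plus one negation instead of a full complex multiply. A minimal
// sketch with plain floats (illustration only, not the project's e2f/bf types):
//   (a + b*i) * (-i) = b - a*i
inline void mul_by_minus_i(float &re, float &im) {
  const float tmp = re;
  re = im;   // new real part = old imaginary part
  im = -tmp; // new imaginary part = negated old real part
}
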
@@ -74,17 +74,17 @@ DEVICE_FORCEINLINE void size_8_inv_dit(e2f *x) {
   // first stage
 #pragma unroll
   for (unsigned i{0}; i < 4; i++) {
-    const e2f tmp = x[i];
-    x[i] = e2f::add(tmp, x[i + 4]);
-    x[i + 4] = e2f::sub(tmp, x[i + 4]);
+    const e2f tmp = x[i];
+    x[i] = e2f::add(tmp, x[i + 4]);
+    x[i + 4] = e2f::sub(tmp, x[i + 4]);
   }

   // second stage
 #pragma unroll
   for (unsigned i{0}; i < 2; i++) {
-    const e2f tmp = x[i];
-    x[i] = e2f::add(tmp, x[i + 2]);
-    x[i + 2] = e2f::sub(tmp, x[i + 2]);
+    const e2f tmp = x[i];
+    x[i] = e2f::add(tmp, x[i + 2]);
+    x[i + 2] = e2f::sub(tmp, x[i + 2]);
   }
   // x[4] = x[4] + W_1_4_INV * (x[6].real + i * x[6].imag)
   //      = x[4] + i * (x[6].real + i * x[6].imag)
@@ -94,19 +94,19 @@ DEVICE_FORCEINLINE void size_8_inv_dit(e2f *x) {
   //      = x[4] + (x[6].imag - i * x[6].real)
 #pragma unroll
   for (unsigned i{4}; i < 6; i++) {
-    const e2f tmp0 = x[i];
-    x[i][0] = bf::sub(x[i][0], x[i + 2][1]);
-    x[i][1] = bf::add(x[i][1], x[i + 2][0]);
-    const bf tmp1 = x[i + 2][0];
-    x[i + 2][0] = bf::add(tmp0[0], x[i + 2][1]);
-    x[i + 2][1] = bf::sub(tmp0[1], tmp1);
+    const e2f tmp0 = x[i];
+    x[i][0] = bf::sub(x[i][0], x[i + 2][1]);
+    x[i][1] = bf::add(x[i][1], x[i + 2][0]);
+    const bf tmp1 = x[i + 2][0];
+    x[i + 2][0] = bf::add(tmp0[0], x[i + 2][1]);
+    x[i + 2][1] = bf::sub(tmp0[1], tmp1);
   }

   // third stage
   {
     // x[3] = W_1_4_INV * x[3]
     //      = i * (x[3].real + i * x[3].imag)
-    //      = -x[3].imag + i * x[3].real)
+    //      = -x[3].imag + i * x[3].real
     const bf tmp = x[3][0];
     x[3][0] = bf::neg(x[3][1]);
     x[3][1] = tmp;
@@ -115,9 +115,9 @@ DEVICE_FORCEINLINE void size_8_inv_dit(e2f *x) {
   x[7] = e2f::mul(W_3_8_INV, x[7]); // don't bother optimizing, marginal gains
 #pragma unroll
   for (unsigned i{0}; i < 8; i += 2) {
-    const e2f tmp = x[i];
-    x[i] = e2f::add(tmp, x[i + 1]);
-    x[i + 1] = e2f::sub(tmp, x[i + 1]);
+    const e2f tmp = x[i];
+    x[i] = e2f::add(tmp, x[i + 1]);
+    x[i + 1] = e2f::sub(tmp, x[i + 1]);
   }

   // undo bitrev
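
// The inverse transform uses W_1_4_INV = i instead, so x + i*y and x - i*y are again formed
// componentwise. A minimal sketch with plain floats (illustration only, not e2f/bf),
// mirroring the tmp0/tmp1 bookkeeping of the loop above:
//   x + i*y = (x.re - y.im) + (x.im + y.re)*i
//   x - i*y = (x.re + y.im) + (x.im - y.re)*i
inline void butterfly_by_i(float &x_re, float &x_im, float &y_re, float &y_im) {
  const float a = x_re, b = x_im; // old x
  x_re = a - y_im;
  x_im = b + y_re;
  const float c = y_re;           // old y.re, about to be overwritten
  y_re = a + y_im;
  y_im = b - c;
}
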
@@ -129,8 +129,7 @@ DEVICE_FORCEINLINE void size_8_inv_dit(e2f *x) {
   x[6] = tmp1;
 }

-template <unsigned LOG_RADIX>
-DEVICE_FORCEINLINE unsigned bitrev_by_radix(const unsigned idx, const unsigned bit_chunks) {
+template <unsigned LOG_RADIX> DEVICE_FORCEINLINE unsigned bitrev_by_radix(const unsigned idx, const unsigned bit_chunks) {
   constexpr unsigned RADIX_MASK = (1 << LOG_RADIX) - 1;
   unsigned out{0}, tmp_idx{idx};
   for (unsigned i{0}; i < bit_chunks; i++) {
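
// The rest of this loop body falls outside the hunk, but the name and the RADIX_MASK setup
// suggest the usual chunk-wise bit reversal: idx is treated as bit_chunks groups of
// LOG_RADIX bits and the groups are emitted in reverse order. A hypothetical standalone
// sketch of that behavior (an assumption, not the elided code):
template <unsigned LOG_RADIX> unsigned bitrev_by_radix_sketch(unsigned idx, unsigned bit_chunks) {
  constexpr unsigned RADIX_MASK = (1u << LOG_RADIX) - 1u;
  unsigned out = 0;
  for (unsigned i = 0; i < bit_chunks; i++) {
    out = (out << LOG_RADIX) | (idx & RADIX_MASK); // append the lowest chunk of idx
    idx >>= LOG_RADIX;                             // advance to the next chunk
  }
  return out; // e.g. LOG_RADIX = 2, bit_chunks = 3: 0b00'01'10 -> 0b10'01'00
}
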
@@ -152,7 +151,7 @@ DEVICE_FORCEINLINE void apply_twiddles_same_region(e2f *vals0, e2f *vals1, const
       const auto twiddle = get_twiddle_with_direct_index<true>(v * i * twiddle_stride);
       vals0[i] = e2f::mul(vals0[i], twiddle);
       vals1[i] = e2f::mul(vals1[i], twiddle);
-    }
+    }
   }
 }

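// (Both twiddle helpers, this one and apply_twiddles_distinct_regions below, start their
// loops at i = 1: for i = 0 the exponent v * i * twiddle_stride is 0, so the looked-up
// twiddle is omega^0 = 1 and the multiply would be a no-op.)
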
@@ -166,15 +165,15 @@ DEVICE_FORCEINLINE void apply_twiddles_distinct_regions(e2f *vals0, e2f *vals1,
     for (unsigned i{1}; i < RADIX; i++) {
       const auto twiddle = get_twiddle_with_direct_index<true>(v * i * twiddle_stride);
       vals0[i] = e2f::mul(vals0[i], twiddle);
-    }
+    }
   }
   // exchg_region_1 should never be 0
   const unsigned v = bitrev_by_radix<LOG_RADIX>(exchg_region_1, idx_bit_chunks);
 #pragma unroll
   for (unsigned i{1}; i < RADIX; i++) {
     const auto twiddle = get_twiddle_with_direct_index<true>(v * i * twiddle_stride);
     vals1[i] = e2f::mul(vals1[i], twiddle);
-  }
+  }
 }

-} // namespace airbender::ntt1
+} // namespace airbender::ntt