Skip to content

Commit 107bce8

Browse files
author
mcarilli
committed
clean up geometry calculation + format
1 parent 8305019 commit 107bce8

File tree

5 files changed

+111
-91
lines changed

5 files changed

+111
-91
lines changed

gpu_prover/native/ntt/natural_evals_to_bitrev_Z_radix_8.cu

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ namespace airbender::ntt {
55
EXTERN __launch_bounds__(256, 3) __global__
66
void ab_radix_8_main_domain_evals_to_Z_nonfinal_6_stages_warp(vectorized_e2_matrix_getter<ld_modifier::cg> gmem_in,
77
vectorized_e2_matrix_setter<st_modifier::cg> gmem_out, const unsigned start_stage,
8-
unsigned exchg_region_bit_chunks, const unsigned log_n, const unsigned grid_offset) {
8+
unsigned exchg_region_bit_chunks, const unsigned log_exchg_region_size,
9+
const unsigned tile_gmem_stride, const unsigned log_n, const unsigned grid_offset) {
910
constexpr unsigned WARP_SIZE = 32u;
1011
constexpr unsigned LOG_RADIX = 3u;
1112
constexpr unsigned RADIX = 1 << LOG_RADIX;
@@ -25,12 +26,8 @@ EXTERN __launch_bounds__(256, 3) __global__
2526
const unsigned lane_id = threadIdx.x & 31;
2627
const unsigned tile_id = lane_id >> LOG_TILE_SIZE;
2728
const unsigned lane_in_tile = lane_id & TILE_MASK;
28-
const unsigned log_exchg_region_size = log_n - start_stage * LOG_RADIX;
29-
const unsigned log_tile_gmem_stride = log_exchg_region_size - 2 * LOG_RADIX;
30-
const unsigned log_blocks_per_exchg_region = log_tile_gmem_stride - LOG_TILE_SIZE - LOG_WARPS_PER_BLOCK;
31-
const unsigned tile_gmem_stride = 1 << log_tile_gmem_stride;
32-
const unsigned block_exchg_region = effective_block_idx_x >> log_blocks_per_exchg_region;
33-
const unsigned block_in_exchg_region = effective_block_idx_x & ((1 << log_blocks_per_exchg_region) - 1);
29+
const unsigned block_exchg_region = blockIdx.x; // effective_block_idx_x >> log_blocks_per_exchg_region;
30+
const unsigned block_in_exchg_region = blockIdx.y; // effective_block_idx_x & ((1 << log_blocks_per_exchg_region) - 1);
3431
const unsigned gmem_block_offset = block_exchg_region << log_exchg_region_size;
3532
const unsigned gmem_warp_offset = ((block_in_exchg_region << LOG_WARPS_PER_BLOCK) + warp_id) << LOG_TILE_SIZE;
3633
gmem_in.add_row(gmem_block_offset + gmem_warp_offset);
@@ -61,11 +58,11 @@ EXTERN __launch_bounds__(256, 3) __global__
6158
smem[addr] = vals0[i];
6259
smem[addr + WARP_SIZE] = vals1[i];
6360
}
64-
// #pragma unroll
65-
// for (unsigned i{0}, addr{thread_offset}; i < RADIX; i++, addr += RADIX * tile_gmem_stride) {
66-
// gmem_out.set_at_row(addr, vals0[i]);
67-
// gmem_out.set_at_row(addr + TILES_PER_WARP * tile_gmem_stride, vals1[i]);
68-
// }
61+
// #pragma unroll
62+
// for (unsigned i{0}, addr{thread_offset}; i < RADIX; i++, addr += RADIX * tile_gmem_stride) {
63+
// gmem_out.set_at_row(addr, vals0[i]);
64+
// gmem_out.set_at_row(addr + TILES_PER_WARP * tile_gmem_stride, vals1[i]);
65+
// }
6966

7067
__syncwarp();
7168
}
@@ -89,7 +86,7 @@ EXTERN __launch_bounds__(256, 3) __global__
8986

9087
const unsigned gmem_write_offset = lane_in_tile + tile_id * 2 * RADIX * tile_gmem_stride;
9188
#pragma unroll
92-
for (unsigned i{0}, addr{gmem_write_offset}; i < RADIX; i++, addr += tile_gmem_stride ) {
89+
for (unsigned i{0}, addr{gmem_write_offset}; i < RADIX; i++, addr += tile_gmem_stride) {
9390
gmem_out.set_at_row(addr, vals0[i]);
9491
gmem_out.set_at_row(addr + RADIX * tile_gmem_stride, vals1[i]);
9592
}
@@ -189,7 +186,7 @@ EXTERN __launch_bounds__(256, 3) __global__
189186
}
190187

191188
warp_exchg_region_offset *= RADIX;
192-
const unsigned exchg_region_0 = warp_exchg_region_offset + tile_id * 2;
189+
const unsigned exchg_region_0 = warp_exchg_region_offset + tile_id * 2;
193190
const unsigned exchg_region_1 = exchg_region_0 + 1;
194191
twiddle_stride >>= LOG_RADIX;
195192
apply_twiddles_distinct_regions<LOG_RADIX>(vals0, vals1, exchg_region_0, exchg_region_1, twiddle_stride, ++exchg_region_bit_chunks);
@@ -218,7 +215,7 @@ EXTERN __launch_bounds__(256, 3) __global__
218215
}
219216

220217
warp_exchg_region_offset *= RADIX;
221-
const unsigned exchg_region_0 = warp_exchg_region_offset + lane_id;
218+
const unsigned exchg_region_0 = warp_exchg_region_offset + lane_id;
222219
const unsigned exchg_region_1 = exchg_region_0 + 32;
223220
twiddle_stride >>= LOG_RADIX;
224221
apply_twiddles_distinct_regions<LOG_RADIX>(vals0, vals1, exchg_region_0, exchg_region_1, twiddle_stride, ++exchg_region_bit_chunks);

gpu_prover/native/ntt/radix_8_utils.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ namespace airbender::ntt {
44

55
EXTERN __launch_bounds__(128, 8) __global__
66
void ab_bit_reverse_by_radix_8(vectorized_e2_matrix_getter<ld_modifier::cg> src, vectorized_e2_matrix_setter<st_modifier::cg> dst,
7-
const unsigned bit_chunks, const unsigned log_n) {
7+
const unsigned bit_chunks, const unsigned log_n) {
88
const unsigned n = 1 << log_n;
99
const unsigned l_index = blockIdx.x * blockDim.x + threadIdx.x;
1010
if (l_index >= n)

gpu_prover/native/ntt/radix_8_utils.cuh

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,17 @@ DEVICE_FORCEINLINE void size_8_fwd_dit(e2f *x) {
1111
// first stage
1212
#pragma unroll
1313
for (unsigned i{0}; i < 4; i++) {
14-
const e2f tmp = x[i];
15-
x[i] = e2f::add(tmp, x[i + 4]);
16-
x[i + 4] = e2f::sub(tmp, x[i + 4]);
14+
const e2f tmp = x[i];
15+
x[i] = e2f::add(tmp, x[i + 4]);
16+
x[i + 4] = e2f::sub(tmp, x[i + 4]);
1717
}
1818

1919
// second stage
2020
#pragma unroll
2121
for (unsigned i{0}; i < 2; i++) {
22-
const e2f tmp = x[i];
23-
x[i] = e2f::add(tmp, x[i + 2]);
24-
x[i + 2] = e2f::sub(tmp, x[i + 2]);
22+
const e2f tmp = x[i];
23+
x[i] = e2f::add(tmp, x[i + 2]);
24+
x[i + 2] = e2f::sub(tmp, x[i + 2]);
2525
}
2626
// x[4] = x[4] + W_1_4 * (x[6].real + i * x[6].imag)
2727
// = x[4] + (-i) * (x[6].real + i * x[6].imag)
@@ -31,19 +31,19 @@ DEVICE_FORCEINLINE void size_8_fwd_dit(e2f *x) {
3131
// = x[4] + (-x[6].imag + i * x[6].real)
3232
#pragma unroll
3333
for (unsigned i{4}; i < 6; i++) {
34-
const e2f tmp0 = x[i];
35-
x[i][0] = bf::add(x[i][0], x[i + 2][1]);
36-
x[i][1] = bf::sub(x[i][1], x[i + 2][0]);
37-
const bf tmp1 = x[i + 2][0];
38-
x[i + 2][0] = bf::sub(tmp0[0], x[i + 2][1]);
39-
x[i + 2][1] = bf::add(tmp0[1], tmp1);
34+
const e2f tmp0 = x[i];
35+
x[i][0] = bf::add(x[i][0], x[i + 2][1]);
36+
x[i][1] = bf::sub(x[i][1], x[i + 2][0]);
37+
const bf tmp1 = x[i + 2][0];
38+
x[i + 2][0] = bf::sub(tmp0[0], x[i + 2][1]);
39+
x[i + 2][1] = bf::add(tmp0[1], tmp1);
4040
}
4141

4242
// third stage
4343
{
4444
// x[3] = W_1_4 * x[3]
4545
// = -i * (x[3].real + i * x[3].imag)
46-
// = x[3].imag - i * x[3].real)
46+
// = x[3].imag - i * x[3].real
4747
const bf tmp = x[3][0];
4848
x[3][0] = x[3][1];
4949
x[3][1] = bf::neg(tmp);
@@ -52,9 +52,9 @@ DEVICE_FORCEINLINE void size_8_fwd_dit(e2f *x) {
5252
x[7] = e2f::mul(W_3_8, x[7]); // don't bother optimizing, marginal gains
5353
#pragma unroll
5454
for (unsigned i{0}; i < 8; i += 2) {
55-
const e2f tmp = x[i];
56-
x[i] = e2f::add(tmp, x[i + 1]);
57-
x[i + 1] = e2f::sub(tmp, x[i + 1]);
55+
const e2f tmp = x[i];
56+
x[i] = e2f::add(tmp, x[i + 1]);
57+
x[i + 1] = e2f::sub(tmp, x[i + 1]);
5858
}
5959

6060
// undo bitrev
@@ -74,17 +74,17 @@ DEVICE_FORCEINLINE void size_8_inv_dit(e2f *x) {
7474
// first stage
7575
#pragma unroll
7676
for (unsigned i{0}; i < 4; i++) {
77-
const e2f tmp = x[i];
78-
x[i] = e2f::add(tmp, x[i + 4]);
79-
x[i + 4] = e2f::sub(tmp, x[i + 4]);
77+
const e2f tmp = x[i];
78+
x[i] = e2f::add(tmp, x[i + 4]);
79+
x[i + 4] = e2f::sub(tmp, x[i + 4]);
8080
}
8181

8282
// second stage
8383
#pragma unroll
8484
for (unsigned i{0}; i < 2; i++) {
85-
const e2f tmp = x[i];
86-
x[i] = e2f::add(tmp, x[i + 2]);
87-
x[i + 2] = e2f::sub(tmp, x[i + 2]);
85+
const e2f tmp = x[i];
86+
x[i] = e2f::add(tmp, x[i + 2]);
87+
x[i + 2] = e2f::sub(tmp, x[i + 2]);
8888
}
8989
// x[4] = x[4] + W_1_4_INV * (x[6].real + i * x[6].imag)
9090
// = x[4] + i * (x[6].real + i * x[6].imag)
@@ -94,19 +94,19 @@ DEVICE_FORCEINLINE void size_8_inv_dit(e2f *x) {
9494
// = x[4] + (x[6].imag - i * x[6].real)
9595
#pragma unroll
9696
for (unsigned i{4}; i < 6; i++) {
97-
const e2f tmp0 = x[i];
98-
x[i][0] = bf::sub(x[i][0], x[i + 2][1]);
99-
x[i][1] = bf::add(x[i][1], x[i + 2][0]);
100-
const bf tmp1 = x[i + 2][0];
101-
x[i + 2][0] = bf::add(tmp0[0], x[i + 2][1]);
102-
x[i + 2][1] = bf::sub(tmp0[1], tmp1);
97+
const e2f tmp0 = x[i];
98+
x[i][0] = bf::sub(x[i][0], x[i + 2][1]);
99+
x[i][1] = bf::add(x[i][1], x[i + 2][0]);
100+
const bf tmp1 = x[i + 2][0];
101+
x[i + 2][0] = bf::add(tmp0[0], x[i + 2][1]);
102+
x[i + 2][1] = bf::sub(tmp0[1], tmp1);
103103
}
104104

105105
// third stage
106106
{
107107
// x[3] = W_1_4_INV * x[3]
108108
// = i * (x[3].real + i * x[3].imag)
109-
// = -x[3].imag + i * x[3].real)
109+
// = -x[3].imag + i * x[3].real
110110
const bf tmp = x[3][0];
111111
x[3][0] = bf::neg(x[3][1]);
112112
x[3][1] = tmp;
@@ -115,9 +115,9 @@ DEVICE_FORCEINLINE void size_8_inv_dit(e2f *x) {
115115
x[7] = e2f::mul(W_3_8_INV, x[7]); // don't bother optimizing, marginal gains
116116
#pragma unroll
117117
for (unsigned i{0}; i < 8; i += 2) {
118-
const e2f tmp = x[i];
119-
x[i] = e2f::add(tmp, x[i + 1]);
120-
x[i + 1] = e2f::sub(tmp, x[i + 1]);
118+
const e2f tmp = x[i];
119+
x[i] = e2f::add(tmp, x[i + 1]);
120+
x[i + 1] = e2f::sub(tmp, x[i + 1]);
121121
}
122122

123123
// undo bitrev
@@ -129,8 +129,7 @@ DEVICE_FORCEINLINE void size_8_inv_dit(e2f *x) {
129129
x[6] = tmp1;
130130
}
131131

132-
template <unsigned LOG_RADIX>
133-
DEVICE_FORCEINLINE unsigned bitrev_by_radix(const unsigned idx, const unsigned bit_chunks) {
132+
template <unsigned LOG_RADIX> DEVICE_FORCEINLINE unsigned bitrev_by_radix(const unsigned idx, const unsigned bit_chunks) {
134133
constexpr unsigned RADIX_MASK = (1 << LOG_RADIX) - 1;
135134
unsigned out{0}, tmp_idx{idx};
136135
for (unsigned i{0}; i < bit_chunks; i++) {
@@ -152,7 +151,7 @@ DEVICE_FORCEINLINE void apply_twiddles_same_region(e2f *vals0, e2f *vals1, const
152151
const auto twiddle = get_twiddle_with_direct_index<true>(v * i * twiddle_stride);
153152
vals0[i] = e2f::mul(vals0[i], twiddle);
154153
vals1[i] = e2f::mul(vals1[i], twiddle);
155-
}
154+
}
156155
}
157156
}
158157

@@ -166,15 +165,15 @@ DEVICE_FORCEINLINE void apply_twiddles_distinct_regions(e2f *vals0, e2f *vals1,
166165
for (unsigned i{1}; i < RADIX; i++) {
167166
const auto twiddle = get_twiddle_with_direct_index<true>(v * i * twiddle_stride);
168167
vals0[i] = e2f::mul(vals0[i], twiddle);
169-
}
168+
}
170169
}
171170
// exchg_region_1 should never be 0
172171
const unsigned v = bitrev_by_radix<LOG_RADIX>(exchg_region_1, idx_bit_chunks);
173172
#pragma unroll
174173
for (unsigned i{1}; i < RADIX; i++) {
175174
const auto twiddle = get_twiddle_with_direct_index<true>(v * i * twiddle_stride);
176175
vals1[i] = e2f::mul(vals1[i], twiddle);
177-
}
176+
}
178177
}
179178

180-
} // namespace airbender::ntt1
179+
} // namespace airbender::ntt

gpu_prover/src/ntt/mod.rs

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ pub mod tests;
88
use era_cudart::cuda_kernel;
99
use era_cudart::error::get_last_error;
1010
use era_cudart::event::{CudaEvent, CudaEventCreateFlags};
11-
use era_cudart::execution::{CudaLaunchConfig, KernelFunction};
11+
use era_cudart::execution::{CudaLaunchConfig, Dim3, KernelFunction};
1212
use era_cudart::result::{CudaResult, CudaResultWrap};
1313
use era_cudart::slice::DeviceSlice;
1414
use era_cudart::stream::{CudaStream, CudaStreamWaitEventFlags};
@@ -115,6 +115,21 @@ n2b_multi_stage_kernel!(ab_compressed_coset_evals_to_Z_final_7_stages_warp);
115115
n2b_multi_stage_kernel!(ab_compressed_coset_evals_to_Z_final_8_stages_warp);
116116
n2b_multi_stage_kernel!(ab_compressed_coset_evals_to_Z_final_9_to_12_stages_block);
117117

118+
cuda_kernel!(
119+
N2BRadix8Nonfinal,
120+
n2b_radix_8_nonfinal_kernel,
121+
inputs_matrix: PtrAndStride<BF>,
122+
outputs_matrix: MutPtrAndStride<BF>,
123+
start_stage: u32,
124+
idx_bit_chunks: u32,
125+
log_exchg_region_size: u32,
126+
tile_gmem_stride: u32,
127+
log_n: u32,
128+
grid_offset: u32,
129+
);
130+
131+
n2b_radix_8_nonfinal_kernel!(ab_radix_8_main_domain_evals_to_Z_nonfinal_6_stages_warp);
132+
118133
cuda_kernel!(
119134
N2BRadix8,
120135
n2b_radix_8_kernel,
@@ -126,7 +141,6 @@ cuda_kernel!(
126141
grid_offset: u32,
127142
);
128143

129-
n2b_radix_8_kernel!(ab_radix_8_main_domain_evals_to_Z_nonfinal_6_stages_warp);
130144
n2b_radix_8_kernel!(ab_radix_8_main_domain_evals_to_Z_final_12_stages_block);
131145

132146
pub fn bit_reverse_by_radix_8(
@@ -156,6 +170,24 @@ pub fn bit_reverse_by_radix_8(
156170
BitReverseByRadix8Function(ab_bit_reverse_by_radix_8).launch(&config, &args)
157171
}
158172

173+
fn get_noninitial_grid_helpers(log_n: usize, start_stage: usize) -> (usize, usize, Dim3) {
174+
const LOG_RADIX: usize = 3;
175+
const LOG_TILE_SIZE: usize = 3;
176+
const LOG_WARPS_PER_BLOCK: usize = 3;
177+
let log_exchg_region_size = log_n - start_stage * LOG_RADIX;
178+
let log_tile_gmem_stride = log_exchg_region_size - 2 * LOG_RADIX;
179+
let log_blocks_per_exchg_region = log_tile_gmem_stride - LOG_TILE_SIZE - LOG_WARPS_PER_BLOCK;
180+
let tile_gmem_stride = 1 << log_tile_gmem_stride;
181+
let num_exchg_regions = 1 << (log_n - log_exchg_region_size);
182+
let mut block_dims: Dim3 = (num_exchg_regions as u32).into();
183+
block_dims.y = 1 << log_blocks_per_exchg_region as u32;
184+
assert_eq!(
185+
block_dims.x * block_dims.y,
186+
(1 << log_n).get_chunks_count(4096)
187+
);
188+
(log_exchg_region_size, tile_gmem_stride, block_dims)
189+
}
190+
159191
pub fn natural_evals_to_bitrev_Z_radix_8(
160192
inputs_matrix: &(impl DeviceMatrixChunkImpl<BF> + ?Sized),
161193
outputs_matrix: &mut (impl DeviceMatrixChunkMutImpl<BF> + ?Sized),
@@ -174,31 +206,39 @@ pub fn natural_evals_to_bitrev_Z_radix_8(
174206
let outputs_matrix_const = outputs_matrix.as_ptr_and_stride();
175207
let outputs_matrix_mut = outputs_matrix.as_mut_ptr_and_stride();
176208
let threads = 256;
177-
let blocks = n.get_chunks_count(4096);
178-
let config = CudaLaunchConfig::basic(blocks as u32, threads as u32, stream);
179-
let args = N2BRadix8Arguments::new(
209+
let (log_exchg_region_size, tile_gmem_stride, block_dims) =
210+
get_noninitial_grid_helpers(log_n, start_stage);
211+
let config = CudaLaunchConfig::basic(block_dims, threads as u32, stream);
212+
let args = N2BRadix8NonfinalArguments::new(
180213
inputs_matrix,
181214
outputs_matrix_mut,
182215
start_stage as u32,
183216
exchg_region_bit_chunks as u32,
217+
log_exchg_region_size as u32,
218+
tile_gmem_stride as u32,
184219
log_n as u32,
185220
0,
186221
);
187-
N2BRadix8Function(ab_radix_8_main_domain_evals_to_Z_nonfinal_6_stages_warp).launch(&config, &args)?;
222+
N2BRadix8NonfinalFunction(ab_radix_8_main_domain_evals_to_Z_nonfinal_6_stages_warp)
223+
.launch(&config, &args)?;
188224
start_stage += 2;
189225
exchg_region_bit_chunks += 2;
190226
let threads = 256;
191-
let blocks = n.get_chunks_count(4096);
192-
let config = CudaLaunchConfig::basic(blocks as u32, threads as u32, stream);
193-
let args = N2BRadix8Arguments::new(
227+
let (log_exchg_region_size, tile_gmem_stride, block_dims) =
228+
get_noninitial_grid_helpers(log_n, start_stage);
229+
let config = CudaLaunchConfig::basic(block_dims, threads as u32, stream);
230+
let args = N2BRadix8NonfinalArguments::new(
194231
outputs_matrix_const,
195232
outputs_matrix_mut,
196233
start_stage as u32,
197234
exchg_region_bit_chunks as u32,
235+
log_exchg_region_size as u32,
236+
tile_gmem_stride as u32,
198237
log_n as u32,
199238
0,
200239
);
201-
N2BRadix8Function(ab_radix_8_main_domain_evals_to_Z_nonfinal_6_stages_warp).launch(&config, &args)?;
240+
N2BRadix8NonfinalFunction(ab_radix_8_main_domain_evals_to_Z_nonfinal_6_stages_warp)
241+
.launch(&config, &args)?;
202242
start_stage += 2;
203243
exchg_region_bit_chunks += 2;
204244
let threads = 256;
@@ -212,7 +252,8 @@ pub fn natural_evals_to_bitrev_Z_radix_8(
212252
log_n as u32,
213253
0,
214254
);
215-
N2BRadix8Function(ab_radix_8_main_domain_evals_to_Z_final_12_stages_block).launch(&config, &args)
255+
N2BRadix8Function(ab_radix_8_main_domain_evals_to_Z_final_12_stages_block)
256+
.launch(&config, &args)
216257
}
217258

218259
#[allow(clippy::too_many_arguments)]

0 commit comments

Comments
 (0)