
Commit eab7cf3

feat(gpu_prover): implement partial tree caching (#185)
## What ❔

This PR implements partial tree caching and recomputation.

## Why ❔

Partial tree caching and recomputation stores only the bottom part of the tree and recomputes only a small slice on demand when performing queries. This makes the memory requirement much lower than full tree caching while having almost no negative performance impact.

## Is this a breaking change?

- [ ] Yes
- [x] No

## Checklist

- [x] PR title corresponds to the body of PR (we generate changelog entries from PRs).
- [x] Tests for the changes have been added / updated.
- [ ] Documentation comments have been added / updated.
- [x] Code has been formatted.
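To make the scheme under "Why" concrete, here is a minimal CPU-side sketch (all names, the `LOG_SLICE` constant, and the toy `hash2` are hypothetical stand-ins, not this repo's API): only tree layers at height `LOG_SLICE` and above are cached, and a query recomputes the small slice of the tree below that height from the leaf digests before reading the remaining siblings from the cache.

```rust
// Hypothetical sketch of partial tree caching; `hash2` stands in for the
// real Blake2s node hash and all names are invented for illustration.
type Digest = [u32; 8];

// The kernel in this PR recomputes slices of 32 leaves (one warp).
const LOG_SLICE: u32 = 5;

fn hash2(left: &Digest, right: &Digest) -> Digest {
    // Placeholder combiner; the real tree hashes two children with Blake2s.
    let mut out = [0u32; 8];
    for i in 0..8 {
        out[i] = left[i].rotate_left(7) ^ right[i];
    }
    out
}

struct PartialTreeCache {
    // cached_layers[0] is tree layer LOG_SLICE; each following layer halves.
    cached_layers: Vec<Vec<Digest>>,
}

impl PartialTreeCache {
    // Authentication path for `leaf_index`, given the leaf digests of the
    // 2^LOG_SLICE-leaf slice that contains it.
    fn path(&self, leaf_index: usize, slice: &[Digest; 1 << LOG_SLICE]) -> Vec<Digest> {
        let mut path = Vec::new();
        // Phase 1: recompute the uncached bottom layers, recording siblings.
        let mut nodes = slice.to_vec();
        let mut idx = leaf_index & ((1 << LOG_SLICE) - 1);
        for _ in 0..LOG_SLICE {
            path.push(nodes[idx ^ 1]);
            nodes = nodes.chunks(2).map(|p| hash2(&p[0], &p[1])).collect();
            idx >>= 1;
        }
        // Phase 2: the rest of the path comes straight from the cache.
        // (The real code requests fewer layers than the tree height, so it
        // never indexes into a single-digest layer here.)
        let mut idx = leaf_index >> LOG_SLICE;
        for layer in &self.cached_layers {
            path.push(layer[idx ^ 1]);
            idx >>= 1;
        }
        path
    }
}
```

Storing only layers `LOG_SLICE` and up shrinks the cached digests by roughly a factor of 2^`LOG_SLICE` = 32, while each query only re-hashes one 32-leaf slice, which is what supports the "much lower memory, almost no performance impact" claim above.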
1 parent 8b236ed commit eab7cf3

File tree: 12 files changed (+612, -452 lines)

gpu_prover/native/blake2s.cu

Lines changed: 88 additions & 0 deletions
@@ -9,6 +9,12 @@ typedef uint32_t u32;
 typedef uint64_t u64;
 typedef base_field bf;
 
+#define LOG_WARP_SIZE 5
+constexpr unsigned WARP_SIZE = 1 << LOG_WARP_SIZE;
+constexpr unsigned WARP_MASK = WARP_SIZE - 1;
+
+#define FULL_MASK 0xffffffff
+
 #define ROTR32(x, y) (((x) >> (y)) ^ ((x) << (32 - (y))))
 
 #define G(a, b, c, d, x, y) \
@@ -159,6 +165,88 @@ EXTERN __global__ void ab_gather_merkle_paths_kernel(const unsigned *indexes, co
   results[dst_index] = values[src_index];
 }
 
+EXTERN __global__ void ab_gather_rows_and_merkle_paths_kernel(const unsigned *indexes, const unsigned indexes_count, const bool bit_reverse_indexes,
+                                                              const bf *values, const unsigned log_rows_per_leaf, const unsigned cols_count,
+                                                              const unsigned log_total_leaves_count, matrix_setter<bf, st_modifier::cs> leaf_values,
+                                                              const u32 *tree_bottom, const unsigned layers_count, u32 *merkle_paths) {
+  const unsigned lane_idx = threadIdx.x;
+  const unsigned idx = blockIdx.x;
+  const unsigned index_warp = indexes[idx];
+  const unsigned index_lane = index_warp & ~WARP_MASK | lane_idx;
+  const bool is_output_lane = index_warp == index_lane;
+  const unsigned leaf_index = bit_reverse_indexes ? __brev(index_lane) >> (32 - log_total_leaves_count) : index_lane;
+  values += leaf_index << log_rows_per_leaf;
+  leaf_values.add_row(idx);
+  merkle_paths += idx * STATE_SIZE;
+  const unsigned row_mask = (1u << log_rows_per_leaf) - 1;
+  auto read = [=](const unsigned offset) {
+    const unsigned row = offset & row_mask;
+    const unsigned col = offset >> log_rows_per_leaf;
+    const auto address = values + row + (col << (log_rows_per_leaf + log_total_leaves_count));
+    return col < cols_count ? bf::into_canonical_u32(load_cs(address)) : 0;
+  };
+  u32 state[STATE_SIZE];
+  u32 block[BLOCK_SIZE];
+  initialize(state);
+  u32 t = 0;
+  const unsigned values_count = cols_count << log_rows_per_leaf;
+  unsigned offset = 0;
+  while (offset < values_count) {
+    const unsigned remaining = values_count - offset;
+    const bool is_final_block = remaining <= BLOCK_SIZE;
+#pragma unroll
+    for (unsigned i = 0; i < BLOCK_SIZE; i++, offset++) {
+      const u32 value = read(offset);
+      block[i] = value;
+      if (offset >= values_count)
+        continue;
+      if (is_output_lane)
+        leaf_values.set(bf(value));
+      leaf_values.inc_col();
+    }
+    if (is_final_block)
+      compress<true>(state, t, block, remaining);
+    else
+      compress<false>(state, t, block, BLOCK_SIZE);
+  }
+#pragma unroll
+  for (unsigned layer = 0; layer < LOG_WARP_SIZE; layer++) {
+    u32 other_state[STATE_SIZE];
+    const bool take_other_first = lane_idx >> layer & 1;
+#pragma unroll
+    for (unsigned i = 0; i < STATE_SIZE; i++) {
+      other_state[i] = __shfl_xor_sync(FULL_MASK, state[i], 1 << layer);
+      if (is_output_lane)
+        merkle_paths[i] = other_state[i];
+      if (take_other_first) {
+        block[i] = other_state[i];
+        block[i + STATE_SIZE] = state[i];
+      } else {
+        block[i] = state[i];
+        block[i + STATE_SIZE] = other_state[i];
+      }
+    }
+    initialize(state);
+    t = 0;
+    compress<true>(state, t, block, BLOCK_SIZE);
+    merkle_paths += indexes_count * STATE_SIZE;
+  }
+  if (lane_idx >= STATE_SIZE)
+    return;
+  unsigned digest_index = index_warp >> LOG_WARP_SIZE;
+  unsigned log_digests_count = log_total_leaves_count - LOG_WARP_SIZE;
+  tree_bottom += lane_idx;
+  merkle_paths += lane_idx;
+  for (unsigned layer = LOG_WARP_SIZE; layer < layers_count; layer++) {
+    const unsigned other_index = digest_index ^ 1;
+    *merkle_paths = *(tree_bottom + other_index * STATE_SIZE);
+    digest_index >>= 1;
+    tree_bottom += (1u << log_digests_count) * STATE_SIZE;
+    log_digests_count--;
+    merkle_paths += indexes_count * STATE_SIZE;
+  }
+}
+
 EXTERN __global__ void ab_blake2s_pow_kernel(const u64 *seed, const u32 bits_count, const u64 max_nonce, volatile u64 *result) {
   const uint32_t digest_mask = 0xffffffff << 32 - bits_count;
   __align__(8) u32 m_u32[BLOCK_SIZE] = {};
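In the kernel above, each warp recomputes the five lowest layers via `__shfl_xor_sync`, then the final loop reads the remaining siblings out of the cached `tree_bottom`, which stores layer `LOG_WARP_SIZE` first and each half-sized layer immediately after it. A host-side Rust transcription of just that tail phase, as a sketch with hypothetical names:

```rust
// Sketch: the kernel's cached-layer walk, written sequentially. Layers are
// laid out back to back, mirroring the kernel's pointer arithmetic
// (tree_bottom += (1u << log_digests_count) * STATE_SIZE per layer).
const LOG_WARP_SIZE: u32 = 5;
type Digest = [u32; 8];

fn gather_cached_path(
    tree_bottom: &[Digest],      // cached layers, lowest cached layer first
    index: u32,                  // queried leaf index
    log_total_leaves_count: u32, // log2 of the total leaf count
    layers_count: u32,           // total path length requested
) -> Vec<Digest> {
    let mut digest_index = (index >> LOG_WARP_SIZE) as usize;
    let mut log_digests_count = log_total_leaves_count - LOG_WARP_SIZE;
    let mut layer_offset = 0usize;
    let mut path = Vec::new();
    for _ in LOG_WARP_SIZE..layers_count {
        // Sibling of the current node within the current cached layer.
        path.push(tree_bottom[layer_offset + (digest_index ^ 1)]);
        digest_index >>= 1;
        // The next, half-sized layer starts right after the current one.
        layer_offset += 1usize << log_digests_count;
        log_digests_count -= 1;
    }
    path
}
```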

gpu_prover/src/blake2s.rs

Lines changed: 63 additions & 36 deletions
@@ -198,7 +198,7 @@ cuda_kernel!(
     )
 );
 
-pub fn gather_merkle_paths_device(
+pub fn gather_merkle_paths(
     indexes: &DeviceSlice<u32>,
     values: &DeviceSlice<Digest>,
    results: &mut DeviceSlice<Digest>,
@@ -228,41 +228,68 @@ pub fn gather_merkle_paths_device(
     GatherMerklePathsFunction::default().launch(&config, &args)
 }
 
-pub fn gather_merkle_paths_host(
-    indexes: &[u32],
-    values: &[Digest],
-    results: &mut [Digest],
+cuda_kernel!(
+    GatherRowsAndMerklePaths,
+    ab_gather_rows_and_merkle_paths_kernel(
+        indexes: *const u32,
+        indexes_count: u32,
+        bit_reverse_indexes: bool,
+        values: *const BF,
+        log_rows_per_leaf: u32,
+        cols_count: u32,
+        log_total_leaves_count: u32,
+        leaf_values: MutPtrAndStride<BF>,
+        tree_bottom: *const Digest,
+        layers_count: u32,
+        merkle_paths: *mut Digest,
+    )
+);
+
+pub fn gather_rows_and_merkle_paths(
+    indexes: &DeviceSlice<u32>,
+    bit_reverse_indexes: bool,
+    values: &DeviceSlice<BF>,
+    log_rows_per_index: u32,
+    leaf_values: &mut (impl DeviceMatrixChunkMutImpl<BF> + ?Sized),
+    tree_bottom: &DeviceSlice<Digest>,
+    merkle_paths: &mut DeviceSlice<Digest>,
     layers_count: u32,
-) {
-    assert!(indexes.len() <= u32::MAX as usize);
-    let indexes_count = indexes.len() as u32;
-    let values_count = values.len();
-    assert!(values_count.is_power_of_two());
-    let log_values_count = values_count.trailing_zeros();
-    assert_ne!(log_values_count, 0);
-    let log_leaves_count = log_values_count - 1;
-    assert!(layers_count < log_leaves_count);
-    assert_eq!(indexes.len() * layers_count as usize, results.len());
-    for layer_index in 0..layers_count {
-        let layer_offset =
-            (1 << (log_leaves_count + 1)) - (1 << (log_leaves_count + 1 - layer_index));
-        for (idx, &leaf_index) in indexes.iter().enumerate() {
-            let hash_offset = ((leaf_index >> layer_index) ^ 1) as usize;
-            let dst_index = idx + (layer_index * indexes_count) as usize;
-            let src_index = layer_offset + hash_offset;
-            results[dst_index] = values[src_index];
-        }
-    }
-    /*
-    const unsigned leaf_index = indexes[idx];
-    const unsigned layer_index = blockIdx.y;
-    const unsigned layer_offset = ((1u << log_leaves_count + 1) - (1u << log_leaves_count + 1 - layer_index)) * STATE_SIZE;
-    const unsigned hash_offset = (leaf_index >> layer_index ^ 1) * STATE_SIZE;
-    const unsigned element_offset = threadIdx.x;
-    const unsigned src_index = layer_offset + hash_offset + element_offset;
-    const unsigned dst_index = layer_index * indexes_count * STATE_SIZE + idx * STATE_SIZE + element_offset;
-    results[dst_index] = values[src_index];
-    */
+    stream: &CudaStream,
+) -> CudaResult<()> {
+    let indexes_len = indexes.len();
+    let values_len = values.len();
+    let cols_count = leaf_values.cols();
+    assert_eq!(values_len % cols_count, 0);
+    let log_rows_count = (values_len / cols_count).trailing_zeros();
+    assert_eq!(leaf_values.rows(), indexes_len << log_rows_per_index);
+    assert!(indexes_len <= u32::MAX as usize);
+    let indexes_count = indexes_len as u32;
+    assert!(layers_count >= LOG_WARP_SIZE);
+    assert_eq!(indexes_len * layers_count as usize, merkle_paths.len());
+    let cols_count = cols_count as u32;
+    let log_total_leaves_count = log_rows_count - log_rows_per_index;
+    let grid_dim = indexes_count;
+    let block_dim = WARP_SIZE;
+    let config = CudaLaunchConfig::basic(grid_dim, block_dim, stream);
+    let indexes = indexes.as_ptr();
+    let values = values.as_ptr();
+    let leaf_values = leaf_values.as_mut_ptr_and_stride();
+    let tree_bottom = tree_bottom.as_ptr();
+    let merkle_paths = merkle_paths.as_mut_ptr();
+    let args = GatherRowsAndMerklePathsArguments::new(
+        indexes,
+        indexes_count,
+        bit_reverse_indexes,
+        values,
+        log_rows_per_index,
+        cols_count,
+        log_total_leaves_count,
+        leaf_values,
+        tree_bottom,
+        layers_count,
+        merkle_paths,
+    );
+    GatherRowsAndMerklePathsFunction::default().launch(&config, &args)
 }
 
 pub fn merkle_tree_cap(
@@ -546,7 +573,7 @@ mod tests {
         let mut results_device = DeviceAllocation::alloc(results_host.len()).unwrap();
         memory_copy_async(&mut indexes_device, &indexes_host, &stream).unwrap();
         memory_copy_async(&mut values_device, &values_host, &stream).unwrap();
-        gather_merkle_paths_device(
+        super::gather_merkle_paths(
             &indexes_device,
             &values_device,
             &mut results_device,
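The asserts in `gather_rows_and_merkle_paths` fix the relationship between the input and output buffer sizes. A small helper that derives the expected shapes from the same inputs, as a sketch (not part of this crate's API; all names here are hypothetical):

```rust
// Sketch: buffer shapes implied by the asserts in gather_rows_and_merkle_paths.
struct GatherShapes {
    leaf_values_rows: usize,     // required rows of the leaf_values matrix
    merkle_paths_len: usize,     // required digests in merkle_paths
    log_total_leaves_count: u32, // passed to the kernel
}

fn gather_shapes(
    indexes_len: usize,
    values_len: usize, // total BF elements over all columns
    cols_count: usize,
    log_rows_per_index: u32,
    layers_count: u32,
) -> GatherShapes {
    assert_eq!(values_len % cols_count, 0);
    let log_rows_count = (values_len / cols_count).trailing_zeros();
    GatherShapes {
        leaf_values_rows: indexes_len << log_rows_per_index,
        merkle_paths_len: indexes_len * layers_count as usize,
        log_total_leaves_count: log_rows_count - log_rows_per_index,
    }
}
```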

gpu_prover/src/execution/gpu_worker.rs

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ fn get_trees_cache_mode(_circuit_type: CircuitType, _context: &ProverContext) ->
     // },
     // _ => TreesCacheMode::CacheFull,
     // }
-    TreesCacheMode::CacheNone
+    TreesCacheMode::CachePartial
 }
 
 fn gpu_worker(

gpu_prover/src/prover/memory.rs

Lines changed: 1 addition & 0 deletions
@@ -196,6 +196,7 @@ pub(crate) fn commit_memory<'a, A: GoodAllocator>(
             }
         },
     }
+    drop(evaluations);
     memory_holder.make_evaluations_sum_to_zero_extend_and_commit(context)?;
     let src_tree_cap_accessors = memory_holder.get_tree_caps_accessors();
     let mut tree_caps = Box::new(None);
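The added `drop(evaluations)` frees that buffer before `make_evaluations_sum_to_zero_extend_and_commit` allocates, trimming peak memory. The pattern in isolation, as a hedged sketch with hypothetical names:

```rust
// Hypothetical illustration: release a large buffer explicitly before an
// allocation-heavy step so the two never have to coexist in memory.
fn preprocess(data: &[u8]) -> u64 {
    data.iter().map(|&b| b as u64).sum()
}

fn commit_heavy(_summary: u64) {
    // stands in for a step that allocates roughly as much as `evaluations`
}

fn commit(evaluations: Vec<u8>) {
    let summary = preprocess(&evaluations);
    drop(evaluations); // freed here instead of at the end of scope
    commit_heavy(summary);
}
```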
