@@ -198,7 +198,7 @@ cuda_kernel!(
198198 )
199199) ;
200200
201- pub fn gather_merkle_paths_device (
201+ pub fn gather_merkle_paths (
202202 indexes : & DeviceSlice < u32 > ,
203203 values : & DeviceSlice < Digest > ,
204204 results : & mut DeviceSlice < Digest > ,
@@ -228,41 +228,68 @@ pub fn gather_merkle_paths_device(
228228 GatherMerklePathsFunction :: default ( ) . launch ( & config, & args)
229229}
230230
231- pub fn gather_merkle_paths_host (
232- indexes : & [ u32 ] ,
233- values : & [ Digest ] ,
234- results : & mut [ Digest ] ,
231+ cuda_kernel ! (
232+ GatherRowsAndMerklePaths ,
233+ ab_gather_rows_and_merkle_paths_kernel(
234+ indexes: * const u32 ,
235+ indexes_count: u32 ,
236+ bit_reverse_indexes: bool ,
237+ values: * const BF ,
238+ log_rows_per_leaf: u32 ,
239+ cols_count: u32 ,
240+ log_total_leaves_count: u32 ,
241+ leaf_values: MutPtrAndStride <BF >,
242+ tree_bottom: * const Digest ,
243+ layers_count: u32 ,
244+ merkle_paths: * mut Digest ,
245+ )
246+ ) ;
247+
248+ pub fn gather_rows_and_merkle_paths (
249+ indexes : & DeviceSlice < u32 > ,
250+ bit_reverse_indexes : bool ,
251+ values : & DeviceSlice < BF > ,
252+ log_rows_per_index : u32 ,
253+ leaf_values : & mut ( impl DeviceMatrixChunkMutImpl < BF > + ?Sized ) ,
254+ tree_bottom : & DeviceSlice < Digest > ,
255+ merkle_paths : & mut DeviceSlice < Digest > ,
235256 layers_count : u32 ,
236- ) {
237- assert ! ( indexes. len( ) <= u32 :: MAX as usize ) ;
238- let indexes_count = indexes. len ( ) as u32 ;
239- let values_count = values. len ( ) ;
240- assert ! ( values_count. is_power_of_two( ) ) ;
241- let log_values_count = values_count. trailing_zeros ( ) ;
242- assert_ne ! ( log_values_count, 0 ) ;
243- let log_leaves_count = log_values_count - 1 ;
244- assert ! ( layers_count < log_leaves_count) ;
245- assert_eq ! ( indexes. len( ) * layers_count as usize , results. len( ) ) ;
246- for layer_index in 0 ..layers_count {
247- let layer_offset =
248- ( 1 << ( log_leaves_count + 1 ) ) - ( 1 << ( log_leaves_count + 1 - layer_index) ) ;
249- for ( idx, & leaf_index) in indexes. iter ( ) . enumerate ( ) {
250- let hash_offset = ( ( leaf_index >> layer_index) ^ 1 ) as usize ;
251- let dst_index = idx + ( layer_index * indexes_count) as usize ;
252- let src_index = layer_offset + hash_offset;
253- results[ dst_index] = values[ src_index] ;
254- }
255- }
256- /*
257- const unsigned leaf_index = indexes[idx];
258- const unsigned layer_index = blockIdx.y;
259- const unsigned layer_offset = ((1u << log_leaves_count + 1) - (1u << log_leaves_count + 1 - layer_index)) * STATE_SIZE;
260- const unsigned hash_offset = (leaf_index >> layer_index ^ 1) * STATE_SIZE;
261- const unsigned element_offset = threadIdx.x;
262- const unsigned src_index = layer_offset + hash_offset + element_offset;
263- const unsigned dst_index = layer_index * indexes_count * STATE_SIZE + idx * STATE_SIZE + element_offset;
264- results[dst_index] = values[src_index];
265- */
257+ stream : & CudaStream ,
258+ ) -> CudaResult < ( ) > {
259+ let indexes_len = indexes. len ( ) ;
260+ let values_len = values. len ( ) ;
261+ let cols_count = leaf_values. cols ( ) ;
262+ assert_eq ! ( values_len % cols_count, 0 ) ;
263+ let log_rows_count = ( values_len / cols_count) . trailing_zeros ( ) ;
264+ assert_eq ! ( leaf_values. rows( ) , indexes_len << log_rows_per_index) ;
265+ assert ! ( indexes_len <= u32 :: MAX as usize ) ;
266+ let indexes_count = indexes_len as u32 ;
267+ assert ! ( layers_count >= LOG_WARP_SIZE ) ;
268+ assert_eq ! ( indexes_len * layers_count as usize , merkle_paths. len( ) ) ;
269+ let cols_count = cols_count as u32 ;
270+ let log_total_leaves_count = log_rows_count - log_rows_per_index;
271+ let grid_dim = indexes_count;
272+ let block_dim = WARP_SIZE ;
273+ let config = CudaLaunchConfig :: basic ( grid_dim, block_dim, stream) ;
274+ let indexes = indexes. as_ptr ( ) ;
275+ let values = values. as_ptr ( ) ;
276+ let leaf_values = leaf_values. as_mut_ptr_and_stride ( ) ;
277+ let tree_bottom = tree_bottom. as_ptr ( ) ;
278+ let merkle_paths = merkle_paths. as_mut_ptr ( ) ;
279+ let args = GatherRowsAndMerklePathsArguments :: new (
280+ indexes,
281+ indexes_count,
282+ bit_reverse_indexes,
283+ values,
284+ log_rows_per_index,
285+ cols_count,
286+ log_total_leaves_count,
287+ leaf_values,
288+ tree_bottom,
289+ layers_count,
290+ merkle_paths,
291+ ) ;
292+ GatherRowsAndMerklePathsFunction :: default ( ) . launch ( & config, & args)
266293}
267294
268295pub fn merkle_tree_cap (
@@ -546,7 +573,7 @@ mod tests {
546573 let mut results_device = DeviceAllocation :: alloc ( results_host. len ( ) ) . unwrap ( ) ;
547574 memory_copy_async ( & mut indexes_device, & indexes_host, & stream) . unwrap ( ) ;
548575 memory_copy_async ( & mut values_device, & values_host, & stream) . unwrap ( ) ;
549- gather_merkle_paths_device (
576+ super :: gather_merkle_paths (
550577 & indexes_device,
551578 & values_device,
552579 & mut results_device,
0 commit comments