11use crate :: allocator:: host:: ConcurrentStaticHostAllocator ;
22use crate :: context:: Context ;
3+ use era_cudart:: device:: { get_device, set_device} ;
34use era_cudart:: memory:: { memory_get_info, CudaHostAllocFlags , HostAllocation } ;
45use era_cudart:: memory_pools:: {
56 AttributeHandler , CudaMemPoolAttributeU64 , CudaOwnedMemPool , DevicePoolAllocation ,
@@ -18,7 +19,6 @@ static DEFAULT_STREAM: CudaStream = CudaStream::DEFAULT;
/// Tunables used when constructing a prover context.
pub struct ProverContextConfig {
    // log2 count of coarse precomputed powers of w (presumably FFT roots of
    // unity — confirm against the twiddle-factor setup in `Context::create`).
    pub powers_of_w_coarse_log_count: u32,
    // log2 of the size, in bytes, of one allocation block; both host and
    // device allocations are sized in these blocks.
    pub allocation_block_log_size: u32,
    // Number of blocks of free device memory deliberately left unclaimed by
    // the memory pool as headroom.
    pub device_slack_blocks: usize,
}
2424
@@ -27,7 +27,6 @@ impl Default for ProverContextConfig {
2727 Self {
2828 powers_of_w_coarse_log_count : 12 ,
2929 allocation_block_log_size : 22 ,
30- host_allocated_blocks : 1 << 10 ,
3130 device_slack_blocks : 1 ,
3231 }
3332 }
@@ -36,6 +35,12 @@ impl Default for ProverContextConfig {
3635pub trait ProverContext {
3736 type HostAllocator : GoodAllocator ;
3837 type Allocation < T : Sync > : DerefMut < Target = DeviceSlice < T > > + CudaSliceMut < T > + Sync ;
38+ fn initialize_host_allocator (
39+ allocation_block_log_size : u32 ,
40+ blocks_count : usize ,
41+ ) -> CudaResult < ( ) > ;
42+ fn get_device_id ( & self ) -> i32 ;
43+ fn switch_to_device ( & self ) -> CudaResult < ( ) > ;
3944 fn get_exec_stream ( & self ) -> & CudaStream ;
4045 fn get_h2d_stream ( & self ) -> & CudaStream ;
4146 fn alloc < T : Sync > ( & self , size : usize ) -> CudaResult < Self :: Allocation < T > > ;
@@ -64,31 +69,18 @@ pub struct MemPoolProverContext<'a> {
6469 pub ( crate ) exec_stream : CudaStream ,
6570 pub ( crate ) h2d_stream : CudaStream ,
6671 pub ( crate ) mem_pool : CudaOwnedMemPool ,
72+ pub ( crate ) device_id : i32 ,
6773 _phantom : PhantomData < & ' a ( ) > ,
6874}
6975
7076impl < ' a > MemPoolProverContext < ' a > {
7177 pub fn new ( config : & ProverContextConfig ) -> CudaResult < Self > {
72- if ConcurrentStaticHostAllocator :: is_initialized_global ( ) {
73- println ! ( "reusing existing static host allocator" ) ;
74- } else {
75- let host_allocation_size =
76- config. host_allocated_blocks << config. allocation_block_log_size ;
77- let host_allocation =
78- HostAllocation :: alloc ( host_allocation_size, CudaHostAllocFlags :: DEFAULT ) ?;
79- ConcurrentStaticHostAllocator :: initialize_global (
80- host_allocation,
81- config. allocation_block_log_size ,
82- ) ;
83- println ! (
84- "initialized static host allocator with {} GB" ,
85- host_allocation_size as f32 / 1024.0 / 1024.0 / 1024.0
86- ) ;
87- }
78+ assert ! ( ConcurrentStaticHostAllocator :: is_initialized_global( ) ) ;
8879 let inner = Context :: create ( 12 ) ?;
8980 let exec_stream = CudaStream :: create ( ) ?;
9081 let h2d_stream = CudaStream :: create ( ) ?;
91- let mem_pool = CudaOwnedMemPool :: create_for_device ( 0 ) ?;
82+ let device_id = get_device ( ) ?;
83+ let mem_pool = CudaOwnedMemPool :: create_for_device ( device_id) ?;
9284 mem_pool. set_attribute ( CudaMemPoolAttributeU64 :: AttrReleaseThreshold , u64:: MAX ) ?;
9385 let ( free, _) = memory_get_info ( ) ?;
9486 let mut size = ( free >> config. allocation_block_log_size ) - config. device_slack_blocks ;
@@ -123,7 +115,7 @@ impl<'a> MemPoolProverContext<'a> {
123115 }
124116 }
125117 println ! (
126- "GPU usable memory: {} GB " ,
118+ "initialized GPU memory pool for device ID {device_id} with {} GB of usable memory " ,
127119 ( size << config. allocation_block_log_size) as f32 / 1024.0 / 1024.0 / 1024.0
128120 ) ;
129121 mem_pool. set_attribute ( CudaMemPoolAttributeU64 :: AttrUsedMemHigh , 0 ) ?;
@@ -133,6 +125,7 @@ impl<'a> MemPoolProverContext<'a> {
133125 exec_stream,
134126 h2d_stream,
135127 mem_pool,
128+ device_id,
136129 _phantom : PhantomData ,
137130 } ;
138131 Ok ( context)
@@ -143,6 +136,32 @@ impl<'a> ProverContext for MemPoolProverContext<'a> {
143136 type HostAllocator = ConcurrentStaticHostAllocator ;
144137 type Allocation < T : Sync > = DevicePoolAllocation < ' a , T > ;
145138
139+ fn initialize_host_allocator (
140+ allocation_block_log_size : u32 ,
141+ blocks_count : usize ,
142+ ) -> CudaResult < ( ) > {
143+ let host_allocation_size = blocks_count << allocation_block_log_size;
144+ let host_allocation =
145+ HostAllocation :: alloc ( host_allocation_size, CudaHostAllocFlags :: DEFAULT ) ?;
146+ ConcurrentStaticHostAllocator :: initialize_global (
147+ host_allocation,
148+ allocation_block_log_size,
149+ ) ;
150+ println ! (
151+ "initialized ConcurrentStaticHostAllocator with {} GB" ,
152+ host_allocation_size as f32 / 1024.0 / 1024.0 / 1024.0
153+ ) ;
154+ Ok ( ( ) )
155+ }
156+
157+ fn get_device_id ( & self ) -> i32 {
158+ self . device_id
159+ }
160+
161+ fn switch_to_device ( & self ) -> CudaResult < ( ) > {
162+ set_device ( self . device_id )
163+ }
164+
146165 fn get_exec_stream ( & self ) -> & CudaStream {
147166 & self . exec_stream
148167 }
@@ -161,8 +180,9 @@ impl<'a> ProverContext for MemPoolProverContext<'a> {
161180 let result: CudaResult < Self :: Allocation < T > > = unsafe { std:: mem:: transmute ( result) } ;
162181 if result. is_err ( ) {
163182 println ! (
164- "failed to allocate {} bytes, currently allocated {} bytes" ,
183+ "failed to allocate {} bytes from GPU memory pool of device ID {} , currently allocated {} bytes" ,
165184 size * size_of:: <T >( ) ,
185+ self . device_id,
166186 self . get_used_mem_current( ) ?
167187 ) ;
168188 }
0 commit comments