Skip to content

Commit 1de6ca0

Browse files
committed
feat(gpu_prover): multi-GPU support
1 parent 45e6af7 commit 1de6ca0

File tree

3 files changed

+188
-96
lines changed


circuit_defs/prover_examples/src/gpu.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@ use trace_and_split::{
4848
use crate::{NUM_QUERIES, POW_BITS};
4949

5050
pub fn create_default_prover_context<'a>() -> MemPoolProverContext<'a> {
51-
let mut prover_context_config = ProverContextConfig::default();
5251
// allocate 1k 4 MB chunks (so around 4GB of host ram).
52+
MemPoolProverContext::initialize_host_allocator(22, 1 << 10).unwrap();
53+
let mut prover_context_config = ProverContextConfig::default();
5354
prover_context_config.allocation_block_log_size = 22;
54-
prover_context_config.host_allocated_blocks = 512;
5555

5656
let prover_context = MemPoolProverContext::new(&prover_context_config).unwrap();
5757
prover_context

gpu_prover/src/prover/context.rs

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::allocator::host::ConcurrentStaticHostAllocator;
22
use crate::context::Context;
3+
use era_cudart::device::{get_device, set_device};
34
use era_cudart::memory::{memory_get_info, CudaHostAllocFlags, HostAllocation};
45
use era_cudart::memory_pools::{
56
AttributeHandler, CudaMemPoolAttributeU64, CudaOwnedMemPool, DevicePoolAllocation,
@@ -18,7 +19,6 @@ static DEFAULT_STREAM: CudaStream = CudaStream::DEFAULT;
1819
pub struct ProverContextConfig {
1920
pub powers_of_w_coarse_log_count: u32,
2021
pub allocation_block_log_size: u32,
21-
pub host_allocated_blocks: usize,
2222
pub device_slack_blocks: usize,
2323
}
2424

@@ -27,7 +27,6 @@ impl Default for ProverContextConfig {
2727
Self {
2828
powers_of_w_coarse_log_count: 12,
2929
allocation_block_log_size: 22,
30-
host_allocated_blocks: 1 << 10,
3130
device_slack_blocks: 1,
3231
}
3332
}
@@ -36,6 +35,12 @@ impl Default for ProverContextConfig {
3635
pub trait ProverContext {
3736
type HostAllocator: GoodAllocator;
3837
type Allocation<T: Sync>: DerefMut<Target = DeviceSlice<T>> + CudaSliceMut<T> + Sync;
38+
fn initialize_host_allocator(
39+
allocation_block_log_size: u32,
40+
blocks_count: usize,
41+
) -> CudaResult<()>;
42+
fn get_device_id(&self) -> i32;
43+
fn switch_to_device(&self) -> CudaResult<()>;
3944
fn get_exec_stream(&self) -> &CudaStream;
4045
fn get_h2d_stream(&self) -> &CudaStream;
4146
fn alloc<T: Sync>(&self, size: usize) -> CudaResult<Self::Allocation<T>>;
@@ -64,31 +69,18 @@ pub struct MemPoolProverContext<'a> {
6469
pub(crate) exec_stream: CudaStream,
6570
pub(crate) h2d_stream: CudaStream,
6671
pub(crate) mem_pool: CudaOwnedMemPool,
72+
pub(crate) device_id: i32,
6773
_phantom: PhantomData<&'a ()>,
6874
}
6975

7076
impl<'a> MemPoolProverContext<'a> {
7177
pub fn new(config: &ProverContextConfig) -> CudaResult<Self> {
72-
if ConcurrentStaticHostAllocator::is_initialized_global() {
73-
println!("reusing existing static host allocator");
74-
} else {
75-
let host_allocation_size =
76-
config.host_allocated_blocks << config.allocation_block_log_size;
77-
let host_allocation =
78-
HostAllocation::alloc(host_allocation_size, CudaHostAllocFlags::DEFAULT)?;
79-
ConcurrentStaticHostAllocator::initialize_global(
80-
host_allocation,
81-
config.allocation_block_log_size,
82-
);
83-
println!(
84-
"initialized static host allocator with {} GB",
85-
host_allocation_size as f32 / 1024.0 / 1024.0 / 1024.0
86-
);
87-
}
78+
assert!(ConcurrentStaticHostAllocator::is_initialized_global());
8879
let inner = Context::create(12)?;
8980
let exec_stream = CudaStream::create()?;
9081
let h2d_stream = CudaStream::create()?;
91-
let mem_pool = CudaOwnedMemPool::create_for_device(0)?;
82+
let device_id = get_device()?;
83+
let mem_pool = CudaOwnedMemPool::create_for_device(device_id)?;
9284
mem_pool.set_attribute(CudaMemPoolAttributeU64::AttrReleaseThreshold, u64::MAX)?;
9385
let (free, _) = memory_get_info()?;
9486
let mut size = (free >> config.allocation_block_log_size) - config.device_slack_blocks;
@@ -123,7 +115,7 @@ impl<'a> MemPoolProverContext<'a> {
123115
}
124116
}
125117
println!(
126-
"GPU usable memory: {} GB",
118+
"initialized GPU memory pool for device ID {device_id} with {} GB of usable memory",
127119
(size << config.allocation_block_log_size) as f32 / 1024.0 / 1024.0 / 1024.0
128120
);
129121
mem_pool.set_attribute(CudaMemPoolAttributeU64::AttrUsedMemHigh, 0)?;
@@ -133,6 +125,7 @@ impl<'a> MemPoolProverContext<'a> {
133125
exec_stream,
134126
h2d_stream,
135127
mem_pool,
128+
device_id,
136129
_phantom: PhantomData,
137130
};
138131
Ok(context)
@@ -143,6 +136,32 @@ impl<'a> ProverContext for MemPoolProverContext<'a> {
143136
type HostAllocator = ConcurrentStaticHostAllocator;
144137
type Allocation<T: Sync> = DevicePoolAllocation<'a, T>;
145138

139+
fn initialize_host_allocator(
140+
allocation_block_log_size: u32,
141+
blocks_count: usize,
142+
) -> CudaResult<()> {
143+
let host_allocation_size = blocks_count << allocation_block_log_size;
144+
let host_allocation =
145+
HostAllocation::alloc(host_allocation_size, CudaHostAllocFlags::DEFAULT)?;
146+
ConcurrentStaticHostAllocator::initialize_global(
147+
host_allocation,
148+
allocation_block_log_size,
149+
);
150+
println!(
151+
"initialized ConcurrentStaticHostAllocator with {} GB",
152+
host_allocation_size as f32 / 1024.0 / 1024.0 / 1024.0
153+
);
154+
Ok(())
155+
}
156+
157+
fn get_device_id(&self) -> i32 {
158+
self.device_id
159+
}
160+
161+
fn switch_to_device(&self) -> CudaResult<()> {
162+
set_device(self.device_id)
163+
}
164+
146165
fn get_exec_stream(&self) -> &CudaStream {
147166
&self.exec_stream
148167
}
@@ -161,8 +180,9 @@ impl<'a> ProverContext for MemPoolProverContext<'a> {
161180
let result: CudaResult<Self::Allocation<T>> = unsafe { std::mem::transmute(result) };
162181
if result.is_err() {
163182
println!(
164-
"failed to allocate {} bytes, currently allocated {} bytes",
183+
"failed to allocate {} bytes from GPU memory pool of device ID {}, currently allocated {} bytes",
165184
size * size_of::<T>(),
185+
self.device_id,
166186
self.get_used_mem_current()?
167187
);
168188
}

0 commit comments

Comments (0)