Skip to content

Commit 45e6af7

Browse files
authored
fix(gpu_prover): improve memory allocation/deallocation ordering (#6)
## What ❔

Enforce allocation order for stage 1 and stage 2 allocations.

## Why ❔

Explicit ordering helps to lower memory fragmentation and removes the allocation peak in stage 2.

## Is this a breaking change?

- [ ] Yes
- [x] No

## Checklist

- [x] PR title corresponds to the body of PR (we generate changelog entries from PRs).
- [x] Code has been formatted.
1 parent a98a8b7 commit 45e6af7

File tree

5 files changed

+204
-78
lines changed

5 files changed

+204
-78
lines changed

gpu_prover/src/prover/proof.rs

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -202,18 +202,51 @@ where
202202
setup.ensure_commitment_produced(context)?;
203203
setup_range.end(stream)?;
204204

205-
// stage 1
206-
let stage_1_range = device_tracing::Range::new("stage_1")?;
207-
stage_1_range.start(stream)?;
208-
let mut stage_1_output = StageOneOutput::new(
205+
let mut stage_1_output = StageOneOutput::allocate_trace_holders(
209206
circuit,
210-
setup,
211-
tracing_data_transfer,
212207
log_lde_factor,
213208
log_tree_cap_size,
209+
context,
210+
)?;
211+
#[cfg(feature = "print_gpu_mem_usage")]
212+
{
213+
print!("after stage_1.allocate_trace_holders ");
214+
context.print_mem_pool_stats()?;
215+
}
216+
217+
let mut stage_2_output = StageTwoOutput::allocate_trace_evaluations(
218+
circuit,
219+
log_lde_factor,
220+
log_tree_cap_size,
221+
context,
222+
)?;
223+
#[cfg(feature = "print_gpu_mem_usage")]
224+
{
225+
print!("after stage_2.allocate_trace_evaluations ");
226+
context.print_mem_pool_stats()?;
227+
}
228+
229+
// witness_generation
230+
let witness_generation_range = device_tracing::Range::new("witness_generation")?;
231+
witness_generation_range.start(stream)?;
232+
stage_1_output.generate_witness(
233+
circuit,
234+
setup,
235+
tracing_data_transfer,
214236
circuit_sequence,
215237
context,
216238
)?;
239+
witness_generation_range.end(stream)?;
240+
#[cfg(feature = "print_gpu_mem_usage")]
241+
{
242+
print!("after generate_witness ");
243+
context.print_mem_pool_stats()?;
244+
}
245+
246+
// stage 1
247+
let stage_1_range = device_tracing::Range::new("stage_1")?;
248+
stage_1_range.start(stream)?;
249+
stage_1_output.commit_witness(circuit, context)?;
217250
stage_1_range.end(stream)?;
218251
#[cfg(feature = "print_gpu_mem_usage")]
219252
{
@@ -222,9 +255,6 @@ where
222255
}
223256

224257
setup.trace_holder.produce_tree_caps(context)?;
225-
stage_1_output.witness_holder.produce_tree_caps(context)?;
226-
stage_1_output.memory_holder.produce_tree_caps(context)?;
227-
stage_1_output.produce_public_inputs(circuit, context)?;
228258

229259
// seed
230260
let seed = initialize_seed::<C>(
@@ -241,18 +271,15 @@ where
241271
// stage 2
242272
let stage_2_range = device_tracing::Range::new("stage_2")?;
243273
stage_2_range.start(stream)?;
244-
let stage_2_output = StageTwoOutput::new(
274+
stage_2_output.generate(
245275
seed.clone(),
246276
circuit,
247277
&cached_data_values,
248278
setup,
249279
&mut stage_1_output,
250-
log_lde_factor,
251-
log_tree_cap_size,
252280
context,
253281
)?;
254282
stage_2_range.end(stream)?;
255-
256283
#[cfg(feature = "print_gpu_mem_usage")]
257284
{
258285
print!("after stage_2 ");
@@ -395,8 +422,8 @@ where
395422
let is_finished_event = CudaEvent::create_with_flags(CudaEventCreateFlags::DISABLE_TIMING)?;
396423
is_finished_event.record(stream)?;
397424

398-
callbacks.extend(stage_1_output.callbacks);
399-
callbacks.extend(stage_2_output.callbacks);
425+
callbacks.extend(stage_1_output.callbacks.unwrap());
426+
callbacks.extend(stage_2_output.callbacks.unwrap());
400427
callbacks.extend(stage_3_output.callbacks);
401428
callbacks.extend(stage_4_output.callbacks);
402429
callbacks.extend(stage_5_output.callbacks);
@@ -411,20 +438,20 @@ where
411438
memory_tree_caps: stage_1_output.memory_holder.get_tree_caps(),
412439
setup_tree_caps: setup.trace_holder.get_tree_caps(),
413440
stage_2_tree_caps: stage_2_output.trace_holder.get_tree_caps(),
414-
stage_2_last_row: stage_2_output.last_row.clone(),
441+
stage_2_last_row: stage_2_output.last_row.unwrap(),
415442
stage_2_offset_for_memory_grand_product_poly: stage_2_output.offset_for_grand_product_poly,
416443
stage_2_offset_for_delegation_argument_poly: stage_2_output
417444
.offset_for_sum_over_delegation_poly,
418445
quotient_tree_caps: stage_3_output.trace_holder.get_tree_caps(),
419-
evaluations_at_random_points: stage_4_output.values_at_z.clone(),
446+
evaluations_at_random_points: stage_4_output.values_at_z,
420447
deep_poly_caps: stage_4_output.trace_holder.get_tree_caps(),
421448
intermediate_fri_oracle_caps: stage_5_output
422449
.fri_oracles
423-
.iter()
424-
.map(|o| o.tree_caps.clone())
450+
.into_iter()
451+
.map(|o| o.tree_caps)
425452
.collect_vec(),
426-
last_fri_step_plain_leaf_values: stage_5_output.last_fri_step_plain_leaf_values.clone(),
427-
final_monomial_form: stage_5_output.final_monomials.clone(),
453+
last_fri_step_plain_leaf_values: stage_5_output.last_fri_step_plain_leaf_values,
454+
final_monomial_form: stage_5_output.final_monomials,
428455
pow_output,
429456
queries_output,
430457
circuit_sequence: circuit_sequence as u16,

gpu_prover/src/prover/stage_1.rs

Lines changed: 55 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,26 +27,22 @@ pub(crate) struct StageOneOutput<'a, C: ProverContext> {
2727
pub witness_holder: TraceHolder<BF, C>,
2828
pub memory_holder: TraceHolder<BF, C>,
2929
pub generic_lookup_mapping: Option<C::Allocation<u32>>,
30-
pub callbacks: Callbacks<'a>,
30+
pub callbacks: Option<Callbacks<'a>>,
3131
pub public_inputs: Option<Arc<Mutex<Vec<BF>>>>,
3232
}
3333

3434
impl<'a, C: ProverContext> StageOneOutput<'a, C> {
35-
pub fn new(
35+
pub fn allocate_trace_holders(
3636
circuit: &CompiledCircuitArtifact<BF>,
37-
setup: &SetupPrecomputations<C>,
38-
tracing_data_transfer: TracingDataTransfer<'a, C>,
3937
log_lde_factor: u32,
4038
log_tree_cap_size: u32,
41-
circuit_sequence: usize,
4239
context: &C,
4340
) -> CudaResult<Self> {
4441
let trace_len = circuit.trace_len;
4542
assert!(trace_len.is_power_of_two());
4643
let log_domain_size = trace_len.trailing_zeros();
47-
let witness_subtree = &circuit.witness_layout;
48-
let witness_columns_count = witness_subtree.total_width;
49-
let mut witness_holder = TraceHolder::new(
44+
let witness_columns_count = circuit.witness_layout.total_width;
45+
let witness_holder = TraceHolder::new(
5046
log_domain_size,
5147
log_lde_factor,
5248
0,
@@ -55,9 +51,8 @@ impl<'a, C: ProverContext> StageOneOutput<'a, C> {
5551
true,
5652
context,
5753
)?;
58-
let memory_subtree = &circuit.memory_layout;
59-
let memory_columns_count = memory_subtree.total_width;
60-
let mut memory_holder = TraceHolder::new(
54+
let memory_columns_count = circuit.memory_layout.total_width;
55+
let memory_holder = TraceHolder::new(
6156
log_domain_size,
6257
log_lde_factor,
6358
0,
@@ -66,6 +61,28 @@ impl<'a, C: ProverContext> StageOneOutput<'a, C> {
6661
true,
6762
context,
6863
)?;
64+
Ok(Self {
65+
witness_holder,
66+
memory_holder,
67+
generic_lookup_mapping: None,
68+
callbacks: None,
69+
public_inputs: None,
70+
})
71+
}
72+
73+
pub fn generate_witness(
74+
&mut self,
75+
circuit: &CompiledCircuitArtifact<BF>,
76+
setup: &SetupPrecomputations<C>,
77+
tracing_data_transfer: TracingDataTransfer<'a, C>,
78+
circuit_sequence: usize,
79+
context: &C,
80+
) -> CudaResult<()> {
81+
let trace_len = circuit.trace_len;
82+
assert!(trace_len.is_power_of_two());
83+
let log_domain_size = trace_len.trailing_zeros();
84+
let witness_subtree = &circuit.witness_layout;
85+
let memory_subtree = &circuit.memory_layout;
6986
let generic_lookup_mapping_size = witness_subtree.width_3_lookups.len() << log_domain_size;
7087
let mut generic_lookup_mapping = context.alloc(generic_lookup_mapping_size)?;
7188
let TracingDataTransfer {
@@ -75,6 +92,7 @@ impl<'a, C: ProverContext> StageOneOutput<'a, C> {
7592
transfer,
7693
} = tracing_data_transfer;
7794
transfer.ensure_transferred(context)?;
95+
self.callbacks = Some(transfer.callbacks);
7896
let stream = context.get_exec_stream();
7997
assert_eq!(COMMON_TABLE_WIDTH, 3);
8098
assert_eq!(NUM_COLUMNS_FOR_COMMON_TABLE_WIDTH_SETUP, 4);
@@ -100,6 +118,8 @@ impl<'a, C: ProverContext> StageOneOutput<'a, C> {
100118
+ timestamp_range_check_multiplicities_columns.num_elements,
101119
generic_multiplicities_columns.start
102120
);
121+
let witness_holder = &mut self.witness_holder;
122+
let memory_holder = &mut self.memory_holder;
103123
match data_device {
104124
TracingDataDevice::Main {
105125
setup_and_teardown,
@@ -172,15 +192,26 @@ impl<'a, C: ProverContext> StageOneOutput<'a, C> {
172192
trace_len,
173193
context,
174194
)?;
175-
memory_holder.make_evaluations_sum_to_zero_extend_and_commit(context)?;
176-
witness_holder.make_evaluations_sum_to_zero_extend_and_commit(context)?;
177-
Ok(Self {
178-
witness_holder,
179-
memory_holder,
180-
generic_lookup_mapping: Some(generic_lookup_mapping),
181-
callbacks: transfer.callbacks,
182-
public_inputs: None,
183-
})
195+
self.generic_lookup_mapping = Some(generic_lookup_mapping);
196+
Ok(())
197+
}
198+
199+
pub fn commit_witness(
200+
&mut self,
201+
circuit: &'a CompiledCircuitArtifact<BF>,
202+
context: &C,
203+
) -> CudaResult<()>
204+
where
205+
C::HostAllocator: 'a,
206+
{
207+
self.memory_holder
208+
.make_evaluations_sum_to_zero_extend_and_commit(context)?;
209+
self.memory_holder.produce_tree_caps(context)?;
210+
self.witness_holder
211+
.make_evaluations_sum_to_zero_extend_and_commit(context)?;
212+
self.witness_holder.produce_tree_caps(context)?;
213+
self.produce_public_inputs(circuit, context)?;
214+
Ok(())
184215
}
185216

186217
pub fn produce_public_inputs(
@@ -257,7 +288,10 @@ impl<'a, C: ProverContext> StageOneOutput<'a, C> {
257288
guard.extend(first_row_public_inputs);
258289
guard.extend(one_before_last_row_public_inputs);
259290
};
260-
self.callbacks.schedule(function, stream)?;
291+
self.callbacks
292+
.as_mut()
293+
.unwrap()
294+
.schedule(function, stream)?;
261295
self.public_inputs = Some(public_inputs);
262296
Ok(())
263297
}

0 commit comments

Comments
 (0)