Skip to content

Commit 4be65dc

Browse files
committed
Optimize memory usage in eval_mle_at_point_blocking
1 parent fc1d1a3 commit 4be65dc

File tree

2 files changed

+114
-37
lines changed

2 files changed

+114
-37
lines changed

crates/multilinear/src/eval.rs

Lines changed: 65 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
use hypercube_alloc::{buffer, Buffer, CpuBackend};
1+
use hypercube_alloc::{Buffer, CpuBackend};
22
use hypercube_tensor::{Dimensions, Tensor};
33
use p3_field::{AbstractExtensionField, AbstractField};
44
use rayon::prelude::*;
5+
use std::sync::{Arc, Mutex};
56

67
use crate::{partial_lagrange_blocking, Point};
78

@@ -16,22 +17,70 @@ pub(crate) fn eval_mle_at_point_blocking<
1617
let mut sizes = mle.sizes().to_vec();
1718
sizes.remove(0);
1819
let dimensions = Dimensions::try_from(sizes).unwrap();
19-
let mut dst = Tensor { storage: buffer![], dimensions };
20-
let total_len = dst.total_len();
21-
let dot_products = mle
22-
.as_buffer()
20+
let total_len = dimensions.total_len();
21+
22+
// Pre-allocation of the result buffer
23+
let result = Arc::new(Mutex::new(vec![EF::zero(); total_len]));
24+
25+
// Process in parallel using Rayon
26+
mle.as_buffer()
2327
.par_chunks_exact(mle.strides()[0])
2428
.zip(partial_lagrange.as_buffer().par_iter())
25-
.map(|(chunk, scalar)| chunk.iter().map(|a| scalar.clone() * a.clone()).collect())
26-
.reduce(
27-
|| vec![EF::zero(); total_len],
28-
|mut a, b| {
29-
a.iter_mut().zip(b.iter()).for_each(|(a, b)| *a += b.clone());
30-
a
31-
},
32-
);
29+
.for_each(|(chunk, scalar)| {
30+
// Process each chunk with a thread-local accumulator
31+
let mut local_result = vec![EF::zero(); total_len];
32+
33+
// Avoid allocation in the inner loop
34+
for (i, a) in chunk.iter().enumerate() {
35+
if i < total_len {
36+
// Compute scalar * a directly into the accumulator
37+
local_result[i] = scalar.clone() * a.clone();
38+
}
39+
}
40+
41+
// Update the global result with our local computation
42+
let result_clone = Arc::clone(&result);
43+
let mut global_result = result_clone.lock().unwrap();
44+
for i in 0..total_len {
45+
global_result[i] += local_result[i].clone();
46+
}
47+
});
3348

34-
let dot_products = Buffer::from(dot_products);
35-
dst.storage = dot_products;
36-
dst
49+
// Create the final tensor
50+
let result_buffer = Buffer::from(Arc::try_unwrap(result).unwrap().into_inner().unwrap());
51+
Tensor { storage: result_buffer, dimensions }
3752
}
53+
54+
// Add a specialized implementation for the case when the number of polynomials is small
55+
pub(crate) fn eval_mle_at_point_small_batch<
56+
F: AbstractField + Sync,
57+
EF: AbstractExtensionField<F> + Send + Sync,
58+
>(
59+
mle: &Tensor<F, CpuBackend>,
60+
point: &Point<EF, CpuBackend>,
61+
) -> Tensor<EF, CpuBackend> {
62+
// For small batches (fewer than 4 polynomials), use a different approach
63+
// that avoids the overhead of parallelization
64+
let partial_lagrange = partial_lagrange_blocking(point);
65+
let mut sizes = mle.sizes().to_vec();
66+
sizes.remove(0);
67+
let dimensions = Dimensions::try_from(sizes).unwrap();
68+
let total_len = dimensions.total_len();
69+
70+
// Direct computation without parallelization for small batches
71+
let mut result = vec![EF::zero(); total_len];
72+
73+
for (chunk, scalar) in mle.as_buffer()
74+
.chunks_exact(mle.strides()[0])
75+
.zip(partial_lagrange.as_buffer().iter())
76+
{
77+
for (i, a) in chunk.iter().enumerate() {
78+
if i < total_len {
79+
result[i] += scalar.clone() * a.clone();
80+
}
81+
}
82+
}
83+
84+
let result_buffer = Buffer::from(result);
85+
Tensor { storage: result_buffer, dimensions }
86+
}

crates/multilinear/src/mle.rs

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ use p3_field::{AbstractExtensionField, AbstractField, Field};
1212
use rand::{distributions::Standard, prelude::Distribution, Rng};
1313
use serde::{Deserialize, Serialize};
1414

15-
use crate::{eval::eval_mle_at_point_blocking, partial_lagrange_blocking, MleBaseBackend, Point};
15+
use crate::eval::{eval_mle_at_point_blocking, eval_mle_at_point_small_batch};
16+
use crate::{partial_lagrange_blocking, MleBaseBackend, Point};
1617

1718
/// A batch of multi-linear polynomials.
1819
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -263,7 +264,15 @@ impl<T> Mle<T, CpuBackend> {
263264
T: AbstractField + 'static + Send + Sync,
264265
E: AbstractExtensionField<T> + 'static + Send + Sync,
265266
{
266-
MleEval::new(eval_mle_at_point_blocking(self.guts(), point))
267+
// Modify this line to use the optimized function when appropriate
268+
let result = if self.num_polynomials() < 4 {
269+
// For small batches, use the specialized implementation
270+
MleEval::new(eval_mle_at_point_small_batch(self.guts(), point))
271+
} else {
272+
// For larger batches, use the parallel implementation
273+
MleEval::new(eval_mle_at_point_blocking(self.guts(), point))
274+
};
275+
result
267276
}
268277

269278
pub fn blocking_partial_lagrange(point: &Point<T>) -> Mle<T, CpuBackend>
@@ -284,25 +293,43 @@ impl<T> Mle<T, CpuBackend> {
284293
///
285294
/// The polynomial f(X,Y) is an important building block in zerocheck and other protocols which use
286295
/// sumcheck.
287-
pub fn full_lagrange_eval<EF>(point_1: &Point<T>, point_2: &Point<EF>) -> EF
288-
where
289-
T: AbstractField,
290-
EF: AbstractExtensionField<T>,
291-
{
292-
assert_eq!(point_1.dimension(), point_2.dimension());
293-
294-
// Iterate over all values in the n-variates X and Y.
295-
point_1
296-
.iter()
297-
.zip(point_2.iter())
298-
.map(|(x, y)| {
299-
// Multiply by (x_i * y_i + (1-x_i) * (1-y_i)).
300-
let prod = y.clone() * x.clone();
301-
prod.clone() + prod + EF::one() - x.clone() - y.clone()
302-
})
303-
.product()
304-
}
305-
}
296+
pub fn full_lagrange_eval<EF>(point_1: &Point<T>, point_2: &Point<EF>) -> EF
297+
where
298+
T: AbstractField,
299+
EF: AbstractExtensionField<T>,
300+
{
301+
assert_eq!(point_1.dimension(), point_2.dimension());
302+
303+
// Iterate over all values in the n-variates X and Y.
304+
point_1
305+
.iter()
306+
.zip(point_2.iter())
307+
.map(|(x, y)| {
308+
// Multiply by (x_i * y_i + (1-x_i) * (1-y_i)).
309+
let prod = y.clone() * x.clone();
310+
prod.clone() + prod + EF::one() - x.clone() - y.clone()
311+
})
312+
.product()
313+
}
314+
}
315+
316+
317+
// pub fn blocking_eval_at<E>(&self, point: &Point<E>) -> MleEval<E>
318+
// where
319+
// T: AbstractField + 'static + Send + Sync,
320+
// E: AbstractExtensionField<T> + 'static + Send + Sync,
321+
// {
322+
// // Modify this line to use the optimized function when appropriate
323+
// let result = if self.num_polynomials() < 4 {
324+
// // For small batches, use the specialized implementation
325+
// MleEval::new(eval_mle_at_point_small_batch(self.guts(), point))
326+
// } else {
327+
// // For larger batches, use the parallel implementation
328+
// MleEval::new(eval_mle_at_point_blocking(self.guts(), point))
329+
// };
330+
// result
331+
// }
332+
// }
306333

307334
// impl<T: AbstractField + Send + Sync> TryInto<p3_matrix::dense::RowMajorMatrix<T>>
308335
// for Mle<T, CpuBackend>
@@ -492,3 +519,4 @@ impl<T> FromIterator<T> for MleEval<T, CpuBackend> {
492519
Self::new(Tensor::from(iter.into_iter().collect::<Vec<_>>()))
493520
}
494521
}
522+

0 commit comments

Comments
 (0)