Skip to content

Commit e005c43

Browse files
perf: add parallel implementation to the numerical_utils module's multiply_columns function (#1002)
# Rationale for this change This PR adds a parallel implementation to the `numerical_utils` module's `multiply_columns` function. Various benchmarks have improved with this change. For example, the Sum Count benchmark query had a 1.42x performance improvement on the Multi-A100 VM with a table size of `1,000,000`. The `MultiplyExpr::final_round_evaluate` saw a 9.42x performance improvement. Before 2.39s <img width="1878" height="341" alt="image" src="https://github.com/user-attachments/assets/b9305381-d162-4f2d-8024-221d96e4c0b0" /> After 1.72s <img width="1878" height="341" alt="image" src="https://github.com/user-attachments/assets/5c128848-685d-494d-a0ff-bff30988ab91" /> This is a first step. Further improvements will investigate doing multiplication inside of the Column module to avoid conversions during each multiplication. Note: creating an `unsafe_scalar_at` function that avoid returning an Option did not show any performance gains when benchmarked. # What changes are included in this PR? - A Rayon implementation is added to the the `numerical_utils` module's `multiply_columns` function # Are these changes tested? Yes
1 parent 6fc4348 commit e005c43

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::base::{
22
database::{try_cast_types, try_scale_cast_types, Column, ColumnOperationResult, ColumnType},
3+
if_rayon,
34
math::decimal::Precision,
45
scalar::{Scalar, ScalarExt},
56
};
@@ -9,6 +10,8 @@ use bumpalo::Bump;
910
use core::{convert::TryInto, ops::Neg};
1011
use itertools::izip;
1112
use num_traits::{NumCast, PrimInt};
13+
#[cfg(feature = "rayon")]
14+
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
1215

1316
/// Add or subtract two columns together.
1417
#[tracing::instrument(level = "debug", skip_all)]
@@ -51,9 +54,17 @@ pub(crate) fn multiply_columns<'a, S: Scalar>(
5154
lhs_len == rhs_len,
5255
"lhs and rhs should have the same length"
5356
);
54-
alloc.alloc_slice_fill_with(lhs_len, |i| {
55-
lhs.scalar_at(i).unwrap() * rhs.scalar_at(i).unwrap()
56-
})
57+
if_rayon!(
58+
{
59+
let result = alloc.alloc_slice_fill_with(lhs_len, |_| S::ZERO);
60+
result.par_iter_mut().enumerate().for_each(|(i, val)| {
61+
*val = lhs.scalar_at(i).unwrap() * rhs.scalar_at(i).unwrap();
62+
});
63+
result
64+
},
65+
alloc.alloc_slice_fill_with(lhs_len, |i| lhs.scalar_at(i).unwrap()
66+
* rhs.scalar_at(i).unwrap())
67+
)
5768
}
5869

5970
/// Divides two columns of data, where the data types are some signed int type(s).

0 commit comments

Comments
 (0)