Skip to content

Commit a671bbc

Browse files
bfv: optimize dot_product_scalar by avoiding redundant iterator clones
The `dot_product_scalar` function previously cloned and traversed the input iterators multiple times for validation and counting, which could be inefficient for complex iterators. This change collects the iterators into `Vec`s of references upfront, ensuring only a single pass over the input iterators and allowing efficient repeated access via slice iteration. Benchmarks show up to 10% performance improvement for large degree parameters (N=16384). While there is minor overhead for simple iterators due to allocation, this approach guarantees performance stability for arbitrary input iterators. Co-authored-by: tlepoint <1345502+tlepoint@users.noreply.github.com>
1 parent 8d67113 commit a671bbc

1 file changed

Lines changed: 25 additions & 7 deletions

File tree

crates/fhe/src/bfv/ops/dot_product.rs

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,21 +58,28 @@ where
5858
I: Iterator<Item = &'a Ciphertext> + Clone,
5959
J: Iterator<Item = &'a Plaintext> + Clone,
6060
{
61-
let count = min(ct.clone().count(), pt.clone().count());
61+
let ct_vec: Vec<&'a Ciphertext> = ct.collect();
62+
let pt_vec: Vec<&'a Plaintext> = pt.collect();
63+
64+
let count = min(ct_vec.len(), pt_vec.len());
6265
if count == 0 {
6366
return Err(Error::DefaultError(
6467
"At least one iterator is empty".to_string(),
6568
));
6669
}
67-
let ct_first = ct.clone().next().unwrap();
70+
let ct_first = ct_vec[0];
6871
let ctx = ct_first[0].ctx();
6972

70-
if izip!(ct.clone(), pt.clone()).any(|(cti, pti)| {
73+
if izip!(
74+
ct_vec.iter().cloned().take(count),
75+
pt_vec.iter().cloned().take(count)
76+
)
77+
.any(|(cti, pti)| {
7178
cti.par != ct_first.par || pti.par != ct_first.par || cti.len() != ct_first.len()
7279
}) {
7380
return Err(Error::DefaultError("Mismatched parameters".to_string()));
7481
}
75-
if ct.clone().any(|cti| cti.len() != ct_first.len()) {
82+
if ct_vec.iter().cloned().any(|cti| cti.len() != ct_first.len()) {
7683
return Err(Error::DefaultError(
7784
"Mismatched number of parts in the ciphertexts".to_string(),
7885
));
@@ -91,8 +98,16 @@ where
9198
let c = (0..ct_first.len())
9299
.map(|i| {
93100
poly_dot_product(
94-
ct.clone().map(|cti| unsafe { cti.get_unchecked(i) }),
95-
pt.clone().map(|pti| &pti.poly_ntt),
101+
ct_vec
102+
.iter()
103+
.cloned()
104+
.take(count)
105+
.map(|cti| unsafe { cti.get_unchecked(i) }),
106+
pt_vec
107+
.iter()
108+
.cloned()
109+
.take(count)
110+
.map(|pti| &pti.poly_ntt),
96111
)
97112
.map_err(Error::MathError)
98113
})
@@ -106,7 +121,10 @@ where
106121
})
107122
} else {
108123
let mut acc = Array::zeros((ct_first.len(), ctx.moduli().len(), ct_first.par.degree()));
109-
for (ciphertext, plaintext) in izip!(ct, pt) {
124+
for (ciphertext, plaintext) in izip!(
125+
ct_vec.iter().cloned().take(count),
126+
pt_vec.iter().cloned().take(count)
127+
) {
110128
let pt_coefficients = plaintext.poly_ntt.coefficients();
111129
for (mut acci, ci) in izip!(acc.outer_iter_mut(), ciphertext.iter()) {
112130
let ci_coefficients = ci.coefficients();

0 commit comments

Comments
 (0)