Skip to content

Commit cb50067

Browse files
authored
feat!: add MultiplyExpr (#18)
1 parent 88f3d40 commit cb50067

12 files changed

+666
-37
lines changed

crates/proof-of-sql/src/base/database/column.rs

+17
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,23 @@ impl<'a, S: Scalar> Column<'a, S> {
131131
}
132132
}
133133

134+
/// Returns element at index as scalar
135+
///
136+
/// Note that if index is out of bounds, this function will return None
137+
pub(crate) fn scalar_at(&self, index: usize) -> Option<S> {
138+
(index < self.len()).then_some(match self {
139+
Self::Boolean(col) => S::from(col[index]),
140+
Self::SmallInt(col) => S::from(col[index]),
141+
Self::Int(col) => S::from(col[index]),
142+
Self::BigInt(col) => S::from(col[index]),
143+
Self::Int128(col) => S::from(col[index]),
144+
Self::Scalar(col) => col[index],
145+
Self::Decimal75(_, _, col) => col[index],
146+
Self::VarChar((_, scals)) => scals[index],
147+
Self::TimestampTZ(_, _, col) => S::from(col[index]),
148+
})
149+
}
150+
134151
/// Convert a column to a vector of Scalar values with scaling
135152
pub(crate) fn to_scalar_with_scaling(&self, scale: i8) -> Vec<S> {
136153
let scale_factor = scale_scalar(S::ONE, scale).expect("Invalid scale factor");

crates/proof-of-sql/src/base/math/decimal.rs

+5
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ pub enum DecimalError {
3030
/// or non-positive aka InvalidPrecision
3131
InvalidPrecision(String),
3232

33+
#[error("Decimal scale is not valid: {0}")]
34+
/// Decimal scale is not valid. Here we use i16 in order to include
35+
/// invalid scale values
36+
InvalidScale(i16),
37+
3338
#[error("Unsupported operation: cannot round decimal: {0}")]
3439
/// This error occurs when attempting to scale a
3540
/// decimal in such a way that a loss of precision occurs.

crates/proof-of-sql/src/sql/ast/mod.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ pub(crate) use add_subtract_expr::AddSubtractExpr;
77
#[cfg(all(test, feature = "blitzar"))]
88
mod add_subtract_expr_test;
99

10+
mod multiply_expr;
11+
use multiply_expr::MultiplyExpr;
12+
#[cfg(all(test, feature = "blitzar"))]
13+
mod multiply_expr_test;
14+
1015
mod filter_expr;
1116
pub(crate) use filter_expr::FilterExpr;
1217
#[cfg(test)]
@@ -59,7 +64,8 @@ pub(crate) use comparison_util::scale_and_subtract;
5964

6065
mod numerical_util;
6166
pub(crate) use numerical_util::{
62-
add_subtract_columns, scale_and_add_subtract_eval, try_add_subtract_column_types,
67+
add_subtract_columns, multiply_columns, scale_and_add_subtract_eval,
68+
try_add_subtract_column_types, try_multiply_column_types,
6369
};
6470

6571
mod equals_expr;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
use super::{ProvableExpr, ProvableExprPlan};
2+
use crate::{
3+
base::{
4+
commitment::Commitment,
5+
database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor},
6+
proof::ProofError,
7+
},
8+
sql::{
9+
ast::{multiply_columns, try_multiply_column_types},
10+
proof::{CountBuilder, ProofBuilder, SumcheckSubpolynomialType, VerificationBuilder},
11+
},
12+
};
13+
use bumpalo::Bump;
14+
use num_traits::One;
15+
use serde::{Deserialize, Serialize};
16+
use std::collections::HashSet;
17+
18+
/// Provable numerical * expression
19+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
20+
pub struct MultiplyExpr<C: Commitment> {
21+
lhs: Box<ProvableExprPlan<C>>,
22+
rhs: Box<ProvableExprPlan<C>>,
23+
}
24+
25+
impl<C: Commitment> MultiplyExpr<C> {
26+
/// Create numerical `*` expression
27+
pub fn new(lhs: Box<ProvableExprPlan<C>>, rhs: Box<ProvableExprPlan<C>>) -> Self {
28+
Self { lhs, rhs }
29+
}
30+
}
31+
32+
impl<C: Commitment> ProvableExpr<C> for MultiplyExpr<C> {
33+
fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> {
34+
self.lhs.count(builder)?;
35+
self.rhs.count(builder)?;
36+
builder.count_subpolynomials(1);
37+
builder.count_intermediate_mles(1);
38+
builder.count_degree(3);
39+
Ok(())
40+
}
41+
42+
fn data_type(&self) -> ColumnType {
43+
try_multiply_column_types(self.lhs.data_type(), self.rhs.data_type())
44+
.expect("Failed to multiply column types")
45+
}
46+
47+
fn result_evaluate<'a>(
48+
&self,
49+
table_length: usize,
50+
alloc: &'a Bump,
51+
accessor: &'a dyn DataAccessor<C::Scalar>,
52+
) -> Column<'a, C::Scalar> {
53+
let lhs_column: Column<'a, C::Scalar> =
54+
self.lhs.result_evaluate(table_length, alloc, accessor);
55+
let rhs_column: Column<'a, C::Scalar> =
56+
self.rhs.result_evaluate(table_length, alloc, accessor);
57+
let scalars = multiply_columns(&lhs_column, &rhs_column, alloc);
58+
Column::Scalar(scalars)
59+
}
60+
61+
#[tracing::instrument(
62+
name = "proofs.sql.ast.and_expr.prover_evaluate",
63+
level = "info",
64+
skip_all
65+
)]
66+
fn prover_evaluate<'a>(
67+
&self,
68+
builder: &mut ProofBuilder<'a, C::Scalar>,
69+
alloc: &'a Bump,
70+
accessor: &'a dyn DataAccessor<C::Scalar>,
71+
) -> Column<'a, C::Scalar> {
72+
let lhs_column: Column<'a, C::Scalar> = self.lhs.prover_evaluate(builder, alloc, accessor);
73+
let rhs_column: Column<'a, C::Scalar> = self.rhs.prover_evaluate(builder, alloc, accessor);
74+
75+
// lhs_times_rhs
76+
let lhs_times_rhs: &'a [C::Scalar] = multiply_columns(&lhs_column, &rhs_column, alloc);
77+
builder.produce_intermediate_mle(lhs_times_rhs);
78+
79+
// subpolynomial: lhs_times_rhs - lhs * rhs
80+
builder.produce_sumcheck_subpolynomial(
81+
SumcheckSubpolynomialType::Identity,
82+
vec![
83+
(C::Scalar::one(), vec![Box::new(lhs_times_rhs)]),
84+
(
85+
-C::Scalar::one(),
86+
vec![Box::new(lhs_column), Box::new(rhs_column)],
87+
),
88+
],
89+
);
90+
Column::Scalar(lhs_times_rhs)
91+
}
92+
93+
fn verifier_evaluate(
94+
&self,
95+
builder: &mut VerificationBuilder<C>,
96+
accessor: &dyn CommitmentAccessor<C>,
97+
) -> Result<C::Scalar, ProofError> {
98+
let lhs = self.lhs.verifier_evaluate(builder, accessor)?;
99+
let rhs = self.rhs.verifier_evaluate(builder, accessor)?;
100+
101+
// lhs_times_rhs
102+
let lhs_times_rhs = builder.consume_intermediate_mle();
103+
104+
// subpolynomial: lhs_times_rhs - lhs * rhs
105+
let eval = builder.mle_evaluations.random_evaluation * (lhs_times_rhs - lhs * rhs);
106+
builder.produce_sumcheck_subpolynomial_evaluation(&eval);
107+
108+
// selection
109+
Ok(lhs_times_rhs)
110+
}
111+
112+
fn get_column_references(&self, columns: &mut HashSet<ColumnRef>) {
113+
self.lhs.get_column_references(columns);
114+
self.rhs.get_column_references(columns);
115+
}
116+
}

0 commit comments

Comments
 (0)