Skip to content

Commit 52313d7

Browse files
authored
fix: ensure reasonable types for numeric ops (#733)
Please be sure to look over the pull request guidelines here: https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md#submit-pr. # Please go through the following checklist - [x] The PR title and commit messages adhere to guidelines here: https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md. In particular `!` is used if and only if at least one breaking change has been introduced. - [x] I have run the ci check script with `source scripts/run_ci_checks.sh`. - [x] I have run the clean commit check script with `source scripts/check_commits.sh`, and the commit history is certified to follow clean commit guidelines as described here: https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/COMMIT_GUIDELINES.md - [x] The latest changes from `main` have been incorporated to this PR by simple rebase if possible, if not, then conflicts are resolved appropriately. # Rationale for this change Currently `ProofExpr` of numeric types require scale casting as well as explicit passing of scales since arithmetic ops on `ProofExpr` always return `Scalar` type columns. In order to port the code to Solidity we need to simplify the logic. <!-- Why are you proposing this change? If this is already explained clearly in the linked issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. Example: Add `NestedLoopJoinExec`. Closes #345. Since we added `HashJoinExec` in #323 it has been possible to do provable inner joins. However performance is not satisfactory in some cases. Hence we need to fix the problem by implement `NestedLoopJoinExec` and speed up the code for `HashJoinExec`. --> # What changes are included in this PR? <!-- There is no need to duplicate the description in the ticket here but it is sometimes worth providing a summary of the individual changes in this PR. Example: - Add `NestedLoopJoinExec`. - Speed up `HashJoinExec`. - Route joins to `NestedLoopJoinExec` if the outer input is sufficiently small. --> 1. Make sure all arithmetic `ProofExpr` have decimal types and return decimal columns to ensure consistency && prevent overflows from being a thing 2. Cap precision to 75 and wrap around as opposed to banning ops that cause precision to be too large # Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 3. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? Example: Yes. --> Yes.
2 parents 0e5b69a + f14fd2e commit 52313d7

File tree

18 files changed

+206
-522
lines changed

18 files changed

+206
-522
lines changed

crates/proof-of-sql-planner/tests/e2e_tests.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ fn test_simple_filter_queries() {
183183
let alloc = Bump::new();
184184
let sql = "select id, name from cats where age > 2;
185185
select * from cats;
186-
select name == $1 as name_eq from cats;";
186+
select name == $1 as name_eq from cats;
187+
select 2 * age as double_age from cats";
187188
let tables: IndexMap<TableRef, Table<DoryScalar>> = indexmap! {
188189
TableRef::from_names(None, "cats") => table(
189190
vec![
@@ -204,6 +205,7 @@ fn test_simple_filter_queries() {
204205
tinyint("age", [13_i8, 2, 0, 4, 4]),
205206
]),
206207
owned_table([boolean("name_eq", [false, false, true, false, false])]),
208+
owned_table([decimal75("double_age", 39, 0, [26_i8, 4, 0, 8, 8])]),
207209
];
208210

209211
// Create public parameters for DynamicDoryEvaluationProof
@@ -449,14 +451,14 @@ fn test_coin() {
449451
vec![
450452
borrowed_varchar("from_address", ["0x1", "0x2", "0x3", "0x2", "0x1"], &alloc),
451453
borrowed_varchar("to_address", ["0x2", "0x3", "0x1", "0x3", "0x2"], &alloc),
452-
borrowed_decimal75("value", 20, 0, [100, 200, 300, 400, 500], &alloc),
454+
borrowed_decimal75("value", 75, 0, [100, 200, 300, 400, 500], &alloc),
453455
borrowed_timestamptz("timestamp", PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), [1, 2, 3, 4, 4], &alloc),
454456
]
455457
)
456458
};
457459
let expected_results: Vec<OwnedTable<DoryScalar>> = vec![owned_table([
458-
decimal75("weighted_value", 62, 0, [100]),
459-
decimal75("total_balance", 41, 0, [0]),
460+
decimal75("weighted_value", 75, 0, [100]),
461+
decimal75("total_balance", 75, 0, [0]),
460462
bigint("num_transactions", [5_i64]),
461463
])];
462464

crates/proof-of-sql/src/base/database/column_type_operation.rs

Lines changed: 67 additions & 122 deletions
Large diffs are not rendered by default.

crates/proof-of-sql/src/sql/parse/query_context_builder.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ mod tests {
315315
base::{
316316
database::{ColumnType, TableRef},
317317
map::indexmap,
318+
math::decimal::Precision,
318319
},
319320
sql::parse::query_expr_tests::schema_accessor_from_table_ref_with_schema,
320321
};
@@ -349,6 +350,6 @@ mod tests {
349350
&Expression::Column(Identifier::try_new("b").unwrap()),
350351
)
351352
.unwrap();
352-
assert_eq!(res, ColumnType::BigInt);
353+
assert_eq!(res, ColumnType::Decimal75(Precision::new(20).unwrap(), 0));
353354
}
354355
}

crates/proof-of-sql/src/sql/proof_exprs/add_expr.rs

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
use super::{add_subtract_columns, scale_and_add_subtract_eval, DynProofExpr, ProofExpr};
1+
use super::{
2+
add_subtract_columns, scale_and_add_subtract_eval, DecimalProofExpr, DynProofExpr, ProofExpr,
3+
};
24
use crate::{
35
base::{
46
database::{
@@ -43,14 +45,8 @@ impl ProofExpr for AddExpr {
4345
) -> PlaceholderResult<Column<'a, S>> {
4446
let lhs_column: Column<'a, S> = self.lhs.first_round_evaluate(alloc, table, params)?;
4547
let rhs_column: Column<'a, S> = self.rhs.first_round_evaluate(alloc, table, params)?;
46-
Ok(Column::Scalar(add_subtract_columns(
47-
lhs_column,
48-
rhs_column,
49-
self.lhs.data_type().scale().unwrap_or(0),
50-
self.rhs.data_type().scale().unwrap_or(0),
51-
alloc,
52-
false,
53-
)))
48+
let res = add_subtract_columns(lhs_column, rhs_column, alloc, false);
49+
Ok(Column::Decimal75(self.precision(), self.scale(), res))
5450
}
5551

5652
#[tracing::instrument(
@@ -73,18 +69,10 @@ impl ProofExpr for AddExpr {
7369
let rhs_column: Column<'a, S> = self
7470
.rhs
7571
.final_round_evaluate(builder, alloc, table, params)?;
76-
let res = Column::Scalar(add_subtract_columns(
77-
lhs_column,
78-
rhs_column,
79-
self.lhs.data_type().scale().unwrap_or(0),
80-
self.rhs.data_type().scale().unwrap_or(0),
81-
alloc,
82-
false,
83-
));
84-
72+
let res = add_subtract_columns(lhs_column, rhs_column, alloc, false);
8573
log::log_memory_usage("End");
8674

87-
Ok(res)
75+
Ok(Column::Decimal75(self.precision(), self.scale(), res))
8876
}
8977

9078
fn verifier_evaluate<S: Scalar>(
@@ -111,3 +99,5 @@ impl ProofExpr for AddExpr {
11199
self.rhs.get_column_references(columns);
112100
}
113101
}
102+
103+
impl DecimalProofExpr for AddExpr {}

crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr_test.rs

Lines changed: 10 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,15 @@ use crate::{
22
base::{
33
commitment::InnerProductProof,
44
database::{
5-
owned_table_utility::*, table_utility::*, Column, OwnedTableTestAccessor, TableRef,
5+
owned_table_utility::*, table_utility::*, OwnedTableTestAccessor, TableRef,
66
TableTestAccessor,
77
},
8-
scalar::test_scalar::TestScalar,
98
},
109
proof_primitive::inner_product::curve_25519_scalar::Curve25519Scalar,
1110
sql::{
12-
proof::{exercise_verification, QueryError, VerifiableQueryResult},
11+
proof::{exercise_verification, VerifiableQueryResult},
1312
proof_exprs::{test_utility::*, DynProofExpr, ProofExpr},
14-
proof_plans::{test_utility::*, DynProofPlan},
15-
AnalyzeError,
13+
proof_plans::test_utility::*,
1614
},
1715
};
1816
use bumpalo::Bump;
@@ -57,7 +55,7 @@ fn we_can_prove_a_typical_add_subtract_query() {
5755
let expected_res = owned_table([
5856
smallint("a", [3_i16, 4]),
5957
bigint("c", [2_i16, 0]),
60-
bigint("res", [4_i64, 5]),
58+
decimal75("res", 20, 0, [4_i64, 5]),
6159
varchar("d", ["efg", "g"]),
6260
]);
6361
assert_eq!(res, expected_res);
@@ -110,133 +108,6 @@ fn we_can_prove_a_typical_add_subtract_query_with_decimals() {
110108
assert_eq!(res, expected_res);
111109
}
112110

113-
// Column type issue tests
114-
#[test]
115-
fn decimal_column_type_issues_error_out_when_producing_provable_ast() {
116-
let data = owned_table([decimal75("a", 75, 2, [1_i16, 2, 3, 4])]);
117-
let t = TableRef::new("sxt", "t");
118-
let accessor =
119-
OwnedTableTestAccessor::<InnerProductProof>::new_from_table(t.clone(), data, 0, ());
120-
assert!(matches!(
121-
DynProofExpr::try_new_add(column(&t, "a", &accessor), const_bigint(1)),
122-
Err(AnalyzeError::DataTypeMismatch { .. })
123-
));
124-
}
125-
126-
// Overflow tests
127-
// select a + b as c from sxt.t where b = 1
128-
#[test]
129-
fn result_expr_can_overflow() {
130-
let data = owned_table([
131-
smallint("a", [i16::MAX, i16::MIN]),
132-
smallint("b", [1_i16, 0]),
133-
]);
134-
let t = TableRef::new("sxt", "t");
135-
let accessor =
136-
OwnedTableTestAccessor::<InnerProductProof>::new_from_table(t.clone(), data, 0, ());
137-
let ast: DynProofPlan = filter(
138-
vec![aliased_plan(
139-
add(column(&t, "a", &accessor), column(&t, "b", &accessor)),
140-
"c",
141-
)],
142-
tab(&t),
143-
equal(column(&t, "b", &accessor), const_bigint(1)),
144-
);
145-
let verifiable_res: VerifiableQueryResult<InnerProductProof> =
146-
VerifiableQueryResult::new(&ast, &accessor, &(), &[]).unwrap();
147-
assert!(matches!(
148-
verifiable_res.verify(&ast, &accessor, &(), &[]),
149-
Err(QueryError::Overflow)
150-
));
151-
}
152-
153-
// select a + b as c from sxt.t where b == 0
154-
#[test]
155-
fn overflow_in_nonselected_rows_doesnt_error_out() {
156-
let data = owned_table([
157-
smallint("a", [i16::MAX, i16::MIN + 1]),
158-
smallint("b", [1_i16, 0]),
159-
]);
160-
let t = TableRef::new("sxt", "t");
161-
let accessor =
162-
OwnedTableTestAccessor::<InnerProductProof>::new_from_table(t.clone(), data, 0, ());
163-
let ast: DynProofPlan = filter(
164-
vec![aliased_plan(
165-
add(column(&t, "a", &accessor), column(&t, "b", &accessor)),
166-
"c",
167-
)],
168-
tab(&t),
169-
equal(column(&t, "b", &accessor), const_bigint(0)),
170-
);
171-
let verifiable_res: VerifiableQueryResult<InnerProductProof> =
172-
VerifiableQueryResult::new(&ast, &accessor, &(), &[]).unwrap();
173-
exercise_verification(&verifiable_res, &ast, &accessor, &t);
174-
let res = verifiable_res
175-
.verify(&ast, &accessor, &(), &[])
176-
.unwrap()
177-
.table;
178-
let expected_res = owned_table([smallint("c", [i16::MIN + 1])]);
179-
assert_eq!(res, expected_res);
180-
}
181-
182-
// select a, b from sxt.t where a + b >= 0
183-
#[test]
184-
fn overflow_in_where_clause_doesnt_error_out() {
185-
let data = owned_table([bigint("a", [i64::MAX, i64::MIN]), smallint("b", [1_i16, 0])]);
186-
let t = TableRef::new("sxt", "t");
187-
let accessor =
188-
OwnedTableTestAccessor::<InnerProductProof>::new_from_table(t.clone(), data, 0, ());
189-
let ast: DynProofPlan = filter(
190-
cols_expr_plan(&t, &["a", "b"], &accessor),
191-
tab(&t),
192-
gte(
193-
add(column(&t, "a", &accessor), column(&t, "b", &accessor)),
194-
const_bigint(0),
195-
),
196-
);
197-
let verifiable_res: VerifiableQueryResult<InnerProductProof> =
198-
VerifiableQueryResult::new(&ast, &accessor, &(), &[]).unwrap();
199-
exercise_verification(&verifiable_res, &ast, &accessor, &t);
200-
let res = verifiable_res
201-
.verify(&ast, &accessor, &(), &[])
202-
.unwrap()
203-
.table;
204-
let expected_res = owned_table([bigint("a", [i64::MAX]), smallint("b", [1_i16])]);
205-
assert_eq!(res, expected_res);
206-
}
207-
208-
// select a + b as c, a - b as d from sxt.t
209-
#[test]
210-
fn result_expr_can_overflow_more() {
211-
let data = owned_table([
212-
bigint("a", [i64::MAX, i64::MIN, i64::MAX, i64::MIN]),
213-
bigint("b", [i64::MAX, i64::MAX, i64::MIN, i64::MIN]),
214-
]);
215-
let t = TableRef::new("sxt", "t");
216-
let accessor =
217-
OwnedTableTestAccessor::<InnerProductProof>::new_from_table(t.clone(), data, 0, ());
218-
let ast: DynProofPlan = filter(
219-
vec![
220-
aliased_plan(
221-
add(column(&t, "a", &accessor), column(&t, "b", &accessor)),
222-
"c",
223-
),
224-
aliased_plan(
225-
subtract(column(&t, "a", &accessor), column(&t, "b", &accessor)),
226-
"d",
227-
),
228-
],
229-
tab(&t),
230-
const_bool(true),
231-
);
232-
let verifiable_res: VerifiableQueryResult<InnerProductProof> =
233-
VerifiableQueryResult::new(&ast, &accessor, &(), &[]).unwrap();
234-
assert!(matches!(
235-
verifiable_res.verify(&ast, &accessor, &(), &[]),
236-
Err(QueryError::Overflow)
237-
));
238-
}
239-
240111
fn test_random_tables_with_given_offset(offset: usize) {
241112
let dist = Uniform::new(-3, 4);
242113
let mut rng = StdRng::from_seed([0u8; 32]);
@@ -283,11 +154,11 @@ fn test_random_tables_with_given_offset(offset: usize) {
283154
and(
284155
equal(
285156
column(&t, "b", &accessor),
286-
const_scalar::<TestScalar, _>(filter_val1.as_str()),
157+
const_scalar::<Curve25519Scalar, _>(filter_val1.as_str()),
287158
),
288159
equal(
289160
column(&t, "c", &accessor),
290-
const_scalar::<TestScalar, _>(filter_val2),
161+
const_scalar::<Curve25519Scalar, _>(filter_val2),
291162
),
292163
),
293164
);
@@ -307,13 +178,14 @@ fn test_random_tables_with_given_offset(offset: usize) {
307178
))
308179
.filter_map(|(a, b, c, d)| {
309180
if b == &filter_val1 && c == &filter_val2 {
310-
Some((i128::from(*a + *c - 4), d.clone()))
181+
Some((Curve25519Scalar::from(*a + *c - 4), d.clone()))
311182
} else {
312183
None
313184
}
314185
})
315186
.multiunzip();
316-
let expected_result = owned_table([varchar("d", expected_d), int128("f", expected_f)]);
187+
let expected_result =
188+
owned_table([varchar("d", expected_d), decimal75("f", 40, 0, expected_f)]);
317189

318190
assert_eq!(expected_result, res);
319191
}
@@ -349,10 +221,6 @@ fn we_can_compute_the_correct_output_of_an_add_subtract_expr_using_first_round_e
349221
let res = add_subtract_expr
350222
.first_round_evaluate(&alloc, &data, &[])
351223
.unwrap();
352-
let expected_res_scalar = [0, 2, 2, 4]
353-
.iter()
354-
.map(|v| Curve25519Scalar::from(*v))
355-
.collect::<Vec<_>>();
356-
let expected_res = Column::Scalar(&expected_res_scalar);
224+
let expected_res = borrowed_decimal75("res", 21, 0, [0_i64, 2, 2, 4], &alloc).1;
357225
assert_eq!(res, expected_res);
358226
}

crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>(
2323
alloc: &'a Bump,
2424
lhs: Column<'a, S>,
2525
rhs: Column<'a, S>,
26-
lhs_scale: i8,
27-
rhs_scale: i8,
2826
is_equal: bool,
2927
) -> AnalyzeResult<&'a [S]> {
3028
let lhs_len = lhs.len();
@@ -48,6 +46,8 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>(
4846
right_type: rhs_type.to_string(),
4947
});
5048
}
49+
let lhs_scale = lhs_type.scale().unwrap_or(0);
50+
let rhs_scale = rhs_type.scale().unwrap_or(0);
5151
let max_scale = max(lhs_scale, rhs_scale);
5252
let lhs_upscale = max_scale - lhs_scale;
5353
let rhs_upscale = max_scale - rhs_scale;

crates/proof-of-sql/src/sql/proof_exprs/equals_expr.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,7 @@ impl ProofExpr for EqualsExpr {
4444

4545
let lhs_column = self.lhs.first_round_evaluate(alloc, table, params)?;
4646
let rhs_column = self.rhs.first_round_evaluate(alloc, table, params)?;
47-
let lhs_scale = self.lhs.data_type().scale().unwrap_or(0);
48-
let rhs_scale = self.rhs.data_type().scale().unwrap_or(0);
49-
let res = scale_and_subtract(alloc, lhs_column, rhs_column, lhs_scale, rhs_scale, true)
47+
let res = scale_and_subtract(alloc, lhs_column, rhs_column, true)
5048
.expect("Failed to scale and subtract");
5149
let res = Column::Boolean(first_round_evaluate_equals_zero(
5250
table.num_rows(),
@@ -75,11 +73,8 @@ impl ProofExpr for EqualsExpr {
7573
let rhs_column = self
7674
.rhs
7775
.final_round_evaluate(builder, alloc, table, params)?;
78-
let lhs_scale = self.lhs.data_type().scale().unwrap_or(0);
79-
let rhs_scale = self.rhs.data_type().scale().unwrap_or(0);
80-
let scale_and_subtract_res =
81-
scale_and_subtract(alloc, lhs_column, rhs_column, lhs_scale, rhs_scale, true)
82-
.expect("Failed to scale and subtract");
76+
let scale_and_subtract_res = scale_and_subtract(alloc, lhs_column, rhs_column, true)
77+
.expect("Failed to scale and subtract");
8378
let res = Column::Boolean(final_round_evaluate_equals_zero(
8479
table.num_rows(),
8580
builder,

crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,12 @@ impl ProofExpr for InequalityExpr {
5353

5454
let lhs_column = self.lhs.first_round_evaluate(alloc, table, params)?;
5555
let rhs_column = self.rhs.first_round_evaluate(alloc, table, params)?;
56-
let lhs_scale = self.lhs.data_type().scale().unwrap_or(0);
57-
let rhs_scale = self.rhs.data_type().scale().unwrap_or(0);
5856
let table_length = table.num_rows();
5957
let diff = if self.is_lt {
60-
scale_and_subtract(alloc, lhs_column, rhs_column, lhs_scale, rhs_scale, false)
58+
scale_and_subtract(alloc, lhs_column, rhs_column, false)
6159
.expect("Failed to scale and subtract")
6260
} else {
63-
scale_and_subtract(alloc, rhs_column, lhs_column, rhs_scale, lhs_scale, false)
61+
scale_and_subtract(alloc, rhs_column, lhs_column, false)
6462
.expect("Failed to scale and subtract")
6563
};
6664

@@ -92,13 +90,11 @@ impl ProofExpr for InequalityExpr {
9290
let rhs_column = self
9391
.rhs
9492
.final_round_evaluate(builder, alloc, table, params)?;
95-
let lhs_scale = self.lhs.data_type().scale().unwrap_or(0);
96-
let rhs_scale = self.rhs.data_type().scale().unwrap_or(0);
9793
let diff = if self.is_lt {
98-
scale_and_subtract(alloc, lhs_column, rhs_column, lhs_scale, rhs_scale, false)
94+
scale_and_subtract(alloc, lhs_column, rhs_column, false)
9995
.expect("Failed to scale and subtract")
10096
} else {
101-
scale_and_subtract(alloc, rhs_column, lhs_column, rhs_scale, lhs_scale, false)
97+
scale_and_subtract(alloc, rhs_column, lhs_column, false)
10298
.expect("Failed to scale and subtract")
10399
};
104100

crates/proof-of-sql/src/sql/proof_exprs/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! This module proves provable expressions.
22
mod proof_expr;
3+
pub(crate) use proof_expr::DecimalProofExpr;
34
pub use proof_expr::ProofExpr;
45
#[cfg(all(test, feature = "blitzar"))]
56
mod proof_expr_test;

0 commit comments

Comments
 (0)