Skip to content

Commit 6fc9b53

Browse files
refactor: account for count alias in roundtrip (#1096)
Please be sure to look over the pull request guidelines here: https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md#submit-pr.

# Please go through the following checklist

- [ ] The PR title and commit messages adhere to the guidelines here: https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md. In particular, `!` is used if and only if at least one breaking change has been introduced.
- [ ] I have run the CI check script with `source scripts/run_ci_checks.sh`.
- [ ] I have run the clean commit check script with `source scripts/check_commits.sh`, and the commit history is certified to follow the clean commit guidelines described here: https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/COMMIT_GUIDELINES.md
- [ ] The latest changes from `main` have been incorporated into this PR by a simple rebase if possible; if not, conflicts have been resolved appropriately.

# Rationale for this change

We need to ensure, for now, that the count alias is correct, so that roundtrip serialization is the identity function.

# What changes are included in this PR?

Enforcing the count alias to be correct on the deserialized plan.

# Are these changes tested?

Yes.
2 parents afd0d85 + b98a066 commit 6fc9b53

File tree

1 file changed

+68
-2
lines changed
  • crates/proof-of-sql/src/sql/evm_proof_plan

1 file changed

+68
-2
lines changed

crates/proof-of-sql/src/sql/evm_proof_plan/plans.rs

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use alloc::{
1717
string::{String, ToString},
1818
vec::Vec,
1919
};
20+
use core::iter;
2021
use serde::{Deserialize, Serialize};
2122
use sqlparser::ast::Ident;
2223

@@ -215,6 +216,28 @@ fn try_unwrap_output_column_names(
215216
Ok(output_column_names)
216217
}
217218

219+
fn try_unwrap_output_column_names_with_count_alias(
220+
output_column_names: Option<&IndexSet<String>>,
221+
length: usize,
222+
count_alias: &String,
223+
) -> EVMProofPlanResult<IndexSet<String>> {
224+
let output_column_names = match output_column_names {
225+
Some(output_column_names) => {
226+
if length > output_column_names.len() {
227+
return Err(EVMProofPlanError::InvalidOutputColumnName);
228+
}
229+
output_column_names.clone()
230+
}
231+
None => (0..length)
232+
.map(|i| i.to_string())
233+
.filter(|name| name != count_alias)
234+
.take(length - 1)
235+
.chain(iter::once(count_alias.clone()))
236+
.collect::<IndexSet<_>>(),
237+
};
238+
Ok(output_column_names)
239+
}
240+
218241
/// Represents a filter execution plan in EVM.
219242
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
220243
pub(crate) struct EVMLegacyFilterExec {
@@ -618,8 +641,11 @@ impl EVMAggregateExec {
618641
output_column_names: Option<&IndexSet<String>>,
619642
) -> EVMProofPlanResult<AggregateExec> {
620643
let required_alias_count = self.group_by_exprs.len() + self.sum_expr.len() + 1;
621-
let output_column_names =
622-
try_unwrap_output_column_names(output_column_names, required_alias_count)?;
644+
let output_column_names = try_unwrap_output_column_names_with_count_alias(
645+
output_column_names,
646+
required_alias_count,
647+
&self.count_alias_name,
648+
)?;
623649
let input = self
624650
.input_plan
625651
.try_into_proof_plan(table_refs, column_refs, None)?;
@@ -805,6 +831,7 @@ mod tests {
805831
proof_plans::{DynProofPlan, SortMergeJoinExec},
806832
},
807833
};
834+
use indexmap::IndexSet;
808835

809836
#[test]
810837
fn we_can_put_projection_exec_in_evm() {
@@ -2455,4 +2482,43 @@ mod tests {
24552482
Err(EVMProofPlanError::InvalidOutputColumnName)
24562483
));
24572484
}
2485+
2486+
/// Without caller-supplied names, the generated names skip the index that equals the
/// count alias and place the count alias last.
#[test]
fn we_can_unwrap_correct_output_column_names_when_none() {
    let count_alias = "0".to_string();
    let resolved =
        try_unwrap_output_column_names_with_count_alias(None, 2, &count_alias).unwrap();

    // The hasher is spelled out because nothing else in this test pins it down.
    let expected: IndexSet<String, core::hash::BuildHasherDefault<ahash::AHasher>> =
        ["1", "0"].iter().map(|name| name.to_string()).collect();
    assert_eq!(resolved, expected);
}
2496+
2497+
/// Caller-supplied names are returned unchanged when there are enough of them.
#[test]
fn we_can_unwrap_correct_output_column_names_when_some() {
    let count_alias = "b".to_string();
    let supplied: IndexSet<String, _> = ["a", "b"].iter().map(|name| name.to_string()).collect();

    let resolved =
        try_unwrap_output_column_names_with_count_alias(Some(&supplied), 2, &count_alias)
            .unwrap();

    assert_eq!(resolved, supplied);
}
2510+
2511+
/// Requesting more names than were supplied is rejected.
///
/// Despite the test name, the failure exercised here comes from asking for three names
/// when only two are supplied, not from the count alias itself.
#[test]
fn we_can_unwrap_err_when_mismatching_count_alias() {
    let count_alias = "b".to_string();
    let supplied: IndexSet<String, _> = ["a", "b"].iter().map(|name| name.to_string()).collect();

    let outcome =
        try_unwrap_output_column_names_with_count_alias(Some(&supplied), 3, &count_alias);

    let err = outcome.unwrap_err();
    assert!(matches!(err, EVMProofPlanError::InvalidOutputColumnName));
}
24582524
}

0 commit comments

Comments
 (0)