Skip to content

Commit de3335f

Browse files
refactor: refactor scale cast numerical utils (#722)
Please be sure to look over the pull request guidelines here: https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md#submit-pr. # Please go through the following checklist - [ ] The PR title and commit messages adhere to guidelines here: https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md. In particular `!` is used if and only if at least one breaking change has been introduced. - [ ] I have run the ci check script with `source scripts/run_ci_checks.sh`. - [ ] I have run the clean commit check script with `source scripts/check_commits.sh`, and the commit history is certified to follow clean commit guidelines as described here: https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/COMMIT_GUIDELINES.md - [ ] The latest changes from `main` have been incorporated to this PR by simple rebase if possible, if not, then conflicts are resolved appropriately. # Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the linked issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. Example: Add `NestedLoopJoinExec`. Closes #345. Since we added `HashJoinExec` in #323 it has been possible to do provable inner joins. However performance is not satisfactory in some cases. Hence we need to fix the problem by implement `NestedLoopJoinExec` and speed up the code for `HashJoinExec`. --> # What changes are included in this PR? <!-- There is no need to duplicate the description in the ticket here but it is sometimes worth providing a summary of the individual changes in this PR. Example: - Add `NestedLoopJoinExec`. - Speed up `HashJoinExec`. - Route joins to `NestedLoopJoinExec` if the outer input is sufficiently small. --> # Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? Example: Yes. -->
2 parents 33b484c + ef266b4 commit de3335f

File tree

3 files changed

+179
-118
lines changed

3 files changed

+179
-118
lines changed

crates/proof-of-sql/src/base/database/column_type_operation.rs

+92-57
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ pub fn try_cast_types(from: ColumnType, to: ColumnType) -> ColumnOperationResult
233233
/// Casting can only be supported if the resulting data type is a superset of the input data type.
234234
/// For example Deciaml(6,1) can be cast to Decimal(7,1), but not vice versa.
235235
#[expect(clippy::missing_panics_doc)]
236-
pub fn try_scale_cast_types(from: ColumnType, to: ColumnType) -> ColumnOperationResult<()> {
236+
pub fn try_decimal_scale_cast_types(from: ColumnType, to: ColumnType) -> ColumnOperationResult<()> {
237237
match (from, to) {
238238
(
239239
ColumnType::TinyInt
@@ -1003,6 +1003,7 @@ mod test {
10031003
));
10041004
}
10051005

1006+
#[expect(clippy::too_many_lines)]
10061007
#[test]
10071008
fn we_can_properly_determine_if_types_are_scale_castable() {
10081009
for from in [
@@ -1016,84 +1017,118 @@ mod test {
10161017
let from_precision = Precision::new(from.precision_value().unwrap()).unwrap();
10171018
let two_prec = Precision::new(2).unwrap();
10181019
let forty_prec = Precision::new(40).unwrap();
1019-
try_scale_cast_types(from, ColumnType::Decimal75(two_prec, 0)).unwrap_err();
1020-
try_scale_cast_types(from, ColumnType::Decimal75(two_prec, -1)).unwrap_err();
1021-
try_scale_cast_types(from, ColumnType::Decimal75(two_prec, 1)).unwrap_err();
1022-
try_scale_cast_types(from, ColumnType::Decimal75(from_precision, 0)).unwrap();
1023-
try_scale_cast_types(from, ColumnType::Decimal75(from_precision, -1)).unwrap_err();
1024-
try_scale_cast_types(from, ColumnType::Decimal75(from_precision, 1)).unwrap_err();
1025-
try_scale_cast_types(from, ColumnType::Decimal75(forty_prec, 0)).unwrap();
1026-
try_scale_cast_types(from, ColumnType::Decimal75(forty_prec, -1)).unwrap_err();
1027-
try_scale_cast_types(from, ColumnType::Decimal75(forty_prec, 1)).unwrap();
1020+
try_decimal_scale_cast_types(from, ColumnType::Decimal75(two_prec, 0)).unwrap_err();
1021+
try_decimal_scale_cast_types(from, ColumnType::Decimal75(two_prec, -1)).unwrap_err();
1022+
try_decimal_scale_cast_types(from, ColumnType::Decimal75(two_prec, 1)).unwrap_err();
1023+
try_decimal_scale_cast_types(from, ColumnType::Decimal75(from_precision, 0)).unwrap();
1024+
try_decimal_scale_cast_types(from, ColumnType::Decimal75(from_precision, -1))
1025+
.unwrap_err();
1026+
try_decimal_scale_cast_types(from, ColumnType::Decimal75(from_precision, 1))
1027+
.unwrap_err();
1028+
try_decimal_scale_cast_types(from, ColumnType::Decimal75(forty_prec, 0)).unwrap();
1029+
try_decimal_scale_cast_types(from, ColumnType::Decimal75(forty_prec, -1)).unwrap_err();
1030+
try_decimal_scale_cast_types(from, ColumnType::Decimal75(forty_prec, 1)).unwrap();
10281031
}
10291032

10301033
let twenty_prec = Precision::new(20).unwrap();
10311034

10321035
// from_with_negative_scale
10331036
let neg_scale = ColumnType::Decimal75(twenty_prec, -3);
10341037

1035-
try_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_prec, -4)).unwrap_err();
1036-
try_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_prec, -3)).unwrap();
1037-
try_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_prec, -2)).unwrap_err();
1038-
try_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_prec, 0)).unwrap_err();
1039-
try_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_prec, 1)).unwrap_err();
1038+
try_decimal_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_prec, -4))
1039+
.unwrap_err();
1040+
try_decimal_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_prec, -3)).unwrap();
1041+
try_decimal_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_prec, -2))
1042+
.unwrap_err();
1043+
try_decimal_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_prec, 0)).unwrap_err();
1044+
try_decimal_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_prec, 1)).unwrap_err();
10401045

10411046
let nineteen_prec = Precision::new(19).unwrap();
1042-
try_scale_cast_types(neg_scale, ColumnType::Decimal75(nineteen_prec, -4)).unwrap_err();
1043-
try_scale_cast_types(neg_scale, ColumnType::Decimal75(nineteen_prec, -3)).unwrap_err();
1044-
try_scale_cast_types(neg_scale, ColumnType::Decimal75(nineteen_prec, -2)).unwrap_err();
1045-
try_scale_cast_types(neg_scale, ColumnType::Decimal75(nineteen_prec, 0)).unwrap_err();
1046-
try_scale_cast_types(neg_scale, ColumnType::Decimal75(nineteen_prec, 1)).unwrap_err();
1047+
try_decimal_scale_cast_types(neg_scale, ColumnType::Decimal75(nineteen_prec, -4))
1048+
.unwrap_err();
1049+
try_decimal_scale_cast_types(neg_scale, ColumnType::Decimal75(nineteen_prec, -3))
1050+
.unwrap_err();
1051+
try_decimal_scale_cast_types(neg_scale, ColumnType::Decimal75(nineteen_prec, -2))
1052+
.unwrap_err();
1053+
try_decimal_scale_cast_types(neg_scale, ColumnType::Decimal75(nineteen_prec, 0))
1054+
.unwrap_err();
1055+
try_decimal_scale_cast_types(neg_scale, ColumnType::Decimal75(nineteen_prec, 1))
1056+
.unwrap_err();
10471057

10481058
let twenty_one_prec = Precision::new(21).unwrap();
1049-
try_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_one_prec, -4)).unwrap_err();
1050-
try_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_one_prec, -3)).unwrap();
1051-
try_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_one_prec, -2)).unwrap();
1052-
try_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_one_prec, 0)).unwrap_err();
1053-
try_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_one_prec, 1)).unwrap_err();
1059+
try_decimal_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_one_prec, -4))
1060+
.unwrap_err();
1061+
try_decimal_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_one_prec, -3))
1062+
.unwrap();
1063+
try_decimal_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_one_prec, -2))
1064+
.unwrap();
1065+
try_decimal_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_one_prec, 0))
1066+
.unwrap_err();
1067+
try_decimal_scale_cast_types(neg_scale, ColumnType::Decimal75(twenty_one_prec, 1))
1068+
.unwrap_err();
10541069

10551070
// from_with_zero_scale
10561071
let zero_scale = ColumnType::Decimal75(twenty_prec, 0);
10571072

1058-
try_scale_cast_types(zero_scale, ColumnType::Decimal75(twenty_prec, -1)).unwrap_err();
1059-
try_scale_cast_types(zero_scale, ColumnType::Decimal75(twenty_prec, 0)).unwrap();
1060-
try_scale_cast_types(zero_scale, ColumnType::Decimal75(twenty_prec, 1)).unwrap_err();
1061-
1062-
try_scale_cast_types(zero_scale, ColumnType::Decimal75(nineteen_prec, -1)).unwrap_err();
1063-
try_scale_cast_types(zero_scale, ColumnType::Decimal75(nineteen_prec, 0)).unwrap_err();
1064-
try_scale_cast_types(zero_scale, ColumnType::Decimal75(nineteen_prec, 1)).unwrap_err();
1065-
try_scale_cast_types(zero_scale, ColumnType::Decimal75(nineteen_prec, 2)).unwrap_err();
1066-
1067-
try_scale_cast_types(zero_scale, ColumnType::Decimal75(twenty_one_prec, -1)).unwrap_err();
1068-
try_scale_cast_types(zero_scale, ColumnType::Decimal75(twenty_one_prec, 0)).unwrap();
1069-
try_scale_cast_types(zero_scale, ColumnType::Decimal75(twenty_one_prec, 1)).unwrap();
1070-
try_scale_cast_types(zero_scale, ColumnType::Decimal75(twenty_one_prec, 2)).unwrap_err();
1073+
try_decimal_scale_cast_types(zero_scale, ColumnType::Decimal75(twenty_prec, -1))
1074+
.unwrap_err();
1075+
try_decimal_scale_cast_types(zero_scale, ColumnType::Decimal75(twenty_prec, 0)).unwrap();
1076+
try_decimal_scale_cast_types(zero_scale, ColumnType::Decimal75(twenty_prec, 1))
1077+
.unwrap_err();
1078+
1079+
try_decimal_scale_cast_types(zero_scale, ColumnType::Decimal75(nineteen_prec, -1))
1080+
.unwrap_err();
1081+
try_decimal_scale_cast_types(zero_scale, ColumnType::Decimal75(nineteen_prec, 0))
1082+
.unwrap_err();
1083+
try_decimal_scale_cast_types(zero_scale, ColumnType::Decimal75(nineteen_prec, 1))
1084+
.unwrap_err();
1085+
try_decimal_scale_cast_types(zero_scale, ColumnType::Decimal75(nineteen_prec, 2))
1086+
.unwrap_err();
1087+
1088+
try_decimal_scale_cast_types(zero_scale, ColumnType::Decimal75(twenty_one_prec, -1))
1089+
.unwrap_err();
1090+
try_decimal_scale_cast_types(zero_scale, ColumnType::Decimal75(twenty_one_prec, 0))
1091+
.unwrap();
1092+
try_decimal_scale_cast_types(zero_scale, ColumnType::Decimal75(twenty_one_prec, 1))
1093+
.unwrap();
1094+
try_decimal_scale_cast_types(zero_scale, ColumnType::Decimal75(twenty_one_prec, 2))
1095+
.unwrap_err();
10711096

10721097
// from_with_positive_scale
10731098
let pos_scale = ColumnType::Decimal75(twenty_prec, 3);
10741099

1075-
try_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_prec, -1)).unwrap_err();
1076-
try_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_prec, 0)).unwrap_err();
1077-
try_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_prec, 2)).unwrap_err();
1078-
try_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_prec, 3)).unwrap();
1079-
try_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_prec, 4)).unwrap_err();
1080-
1081-
try_scale_cast_types(pos_scale, ColumnType::Decimal75(nineteen_prec, -1)).unwrap_err();
1082-
try_scale_cast_types(pos_scale, ColumnType::Decimal75(nineteen_prec, 0)).unwrap_err();
1083-
try_scale_cast_types(pos_scale, ColumnType::Decimal75(nineteen_prec, 2)).unwrap_err();
1084-
try_scale_cast_types(pos_scale, ColumnType::Decimal75(nineteen_prec, 3)).unwrap_err();
1085-
try_scale_cast_types(pos_scale, ColumnType::Decimal75(nineteen_prec, 4)).unwrap_err();
1086-
1087-
try_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_one_prec, -1)).unwrap_err();
1088-
try_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_one_prec, 0)).unwrap_err();
1089-
try_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_one_prec, 2)).unwrap_err();
1090-
try_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_one_prec, 3)).unwrap();
1091-
try_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_one_prec, 4)).unwrap();
1092-
try_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_one_prec, 5)).unwrap_err();
1100+
try_decimal_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_prec, -1))
1101+
.unwrap_err();
1102+
try_decimal_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_prec, 0)).unwrap_err();
1103+
try_decimal_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_prec, 2)).unwrap_err();
1104+
try_decimal_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_prec, 3)).unwrap();
1105+
try_decimal_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_prec, 4)).unwrap_err();
1106+
1107+
try_decimal_scale_cast_types(pos_scale, ColumnType::Decimal75(nineteen_prec, -1))
1108+
.unwrap_err();
1109+
try_decimal_scale_cast_types(pos_scale, ColumnType::Decimal75(nineteen_prec, 0))
1110+
.unwrap_err();
1111+
try_decimal_scale_cast_types(pos_scale, ColumnType::Decimal75(nineteen_prec, 2))
1112+
.unwrap_err();
1113+
try_decimal_scale_cast_types(pos_scale, ColumnType::Decimal75(nineteen_prec, 3))
1114+
.unwrap_err();
1115+
try_decimal_scale_cast_types(pos_scale, ColumnType::Decimal75(nineteen_prec, 4))
1116+
.unwrap_err();
1117+
1118+
try_decimal_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_one_prec, -1))
1119+
.unwrap_err();
1120+
try_decimal_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_one_prec, 0))
1121+
.unwrap_err();
1122+
try_decimal_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_one_prec, 2))
1123+
.unwrap_err();
1124+
try_decimal_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_one_prec, 3)).unwrap();
1125+
try_decimal_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_one_prec, 4)).unwrap();
1126+
try_decimal_scale_cast_types(pos_scale, ColumnType::Decimal75(twenty_one_prec, 5))
1127+
.unwrap_err();
10931128
}
10941129

10951130
#[test]
10961131
fn we_cannot_scale_cast_nonsense_pairings() {
1097-
try_scale_cast_types(ColumnType::Int128, ColumnType::Boolean).unwrap_err();
1132+
try_decimal_scale_cast_types(ColumnType::Int128, ColumnType::Boolean).unwrap_err();
10981133
}
10991134
}

crates/proof-of-sql/src/base/database/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ mod slice_decimal_operation;
1414

1515
mod column_type_operation;
1616
pub use column_type_operation::{
17-
try_add_subtract_column_types, try_cast_types, try_divide_column_types,
18-
try_multiply_column_types, try_scale_cast_types,
17+
try_add_subtract_column_types, try_cast_types, try_decimal_scale_cast_types,
18+
try_divide_column_types, try_multiply_column_types,
1919
};
2020

2121
mod column_arithmetic_operation;

0 commit comments

Comments
 (0)