Skip to content

Commit 73a6a26

Browse files
authored
refactor: exclusively use DecimalError for proof-of-sql decimal-related errors (#58)
1 parent c63d33b commit 73a6a26

File tree

5 files changed

+55
-60
lines changed

5 files changed

+55
-60
lines changed

crates/proof-of-sql/src/base/database/arrow_array_to_column_conversion.rs

+4-7
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
use super::scalar_and_i256_conversions::convert_i256_to_scalar;
2-
use crate::{
3-
base::{database::Column, math::decimal::Precision, scalar::Scalar},
4-
sql::parse::ConversionError,
5-
};
2+
use crate::base::{database::Column, math::decimal::Precision, scalar::Scalar};
63
use arrow::{
74
array::{
85
Array, ArrayRef, BooleanArray, Decimal128Array, Decimal256Array, Int16Array, Int32Array,
@@ -25,15 +22,15 @@ pub enum ArrowArrayToColumnConversionError {
2522
/// This error occurs when trying to convert from an unsupported arrow type.
2623
#[error("unsupported type: attempted conversion from ArrayRef of type {0} to OwnedColumn")]
2724
UnsupportedType(DataType),
25+
/// Variant for decimal errors
26+
#[error(transparent)]
27+
DecimalError(#[from] crate::base::math::decimal::DecimalError),
2828
/// This error occurs when trying to convert from an i256 to a Scalar.
2929
#[error("decimal conversion failed: {0}")]
3030
DecimalConversionFailed(i256),
3131
/// This error occurs when the specified range is out of the bounds of the array.
3232
#[error("index out of bounds: the len is {0} but the index is {1}")]
3333
IndexOutOfBounds(usize, usize),
34-
/// Variant for conversion errors
35-
#[error("conversion error: {0}")]
36-
ConversionError(#[from] ConversionError),
3734
/// Using TimeError to handle all time-related errors
3835
#[error(transparent)]
3936
TimestampConversionError(#[from] PoSQLTimestampError),

crates/proof-of-sql/src/base/math/decimal.rs

+33-39
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,5 @@
11
//! Module for parsing an `IntermediateDecimal` into a `Decimal75`.
2-
use crate::{
3-
base::{
4-
math::decimal::DecimalError::{
5-
IntermediateDecimalConversionError, InvalidPrecision, RoundingError,
6-
},
7-
scalar::Scalar,
8-
},
9-
sql::parse::{
10-
ConversionError::{self, DecimalConversionError},
11-
ConversionResult,
12-
},
13-
};
2+
use crate::base::scalar::{Scalar, ScalarConversionError};
143
use proof_of_sql_parser::intermediate_decimal::{IntermediateDecimal, IntermediateDecimalError};
154
use serde::{Deserialize, Deserializer, Serialize};
165
use thiserror::Error;
@@ -42,13 +31,17 @@ pub enum DecimalError {
4231

4332
/// Errors that may occur when parsing an intermediate decimal
4433
/// into a posql decimal
45-
#[error("Intermediate decimal conversion error: {0}")]
46-
IntermediateDecimalConversionError(IntermediateDecimalError),
34+
#[error(transparent)]
35+
IntermediateDecimalConversionError(#[from] IntermediateDecimalError),
4736
}
4837

49-
impl From<IntermediateDecimalError> for ConversionError {
50-
fn from(err: IntermediateDecimalError) -> ConversionError {
51-
DecimalConversionError(IntermediateDecimalConversionError(err))
38+
/// Result type for decimal operations.
39+
pub type DecimalResult<T> = Result<T, DecimalError>;
40+
41+
// This exists because `TryFrom<arrow::datatypes::DataType>` for `ColumnType` error is String
42+
impl From<DecimalError> for String {
43+
fn from(error: DecimalError) -> Self {
44+
error.to_string()
5245
}
5346
}
5447

@@ -59,12 +52,12 @@ pub(crate) const MAX_SUPPORTED_PRECISION: u8 = 75;
5952

6053
impl Precision {
6154
/// Constructor for creating a Precision instance
62-
pub fn new(value: u8) -> Result<Self, ConversionError> {
55+
pub fn new(value: u8) -> Result<Self, DecimalError> {
6356
if value > MAX_SUPPORTED_PRECISION || value == 0 {
64-
Err(DecimalConversionError(InvalidPrecision(format!(
57+
Err(DecimalError::InvalidPrecision(format!(
6558
"Failed to parse precision. Value of {} exceeds max supported precision of {}",
6659
value, MAX_SUPPORTED_PRECISION
67-
))))
60+
)))
6861
} else {
6962
Ok(Precision(value))
7063
}
@@ -116,48 +109,48 @@ impl<S: Scalar> Decimal<S> {
116109
&self,
117110
new_precision: Precision,
118111
new_scale: i8,
119-
) -> ConversionResult<Decimal<S>> {
112+
) -> DecimalResult<Decimal<S>> {
120113
let scale_factor = new_scale - self.scale;
121114
if scale_factor < 0 || new_precision.value() < self.precision.value() + scale_factor as u8 {
122-
return Err(DecimalConversionError(RoundingError(
115+
return Err(DecimalError::RoundingError(
123116
"Scale factor must be non-negative".to_string(),
124-
)));
117+
));
125118
}
126119
let scaled_value = scale_scalar(self.value, scale_factor)?;
127120
Ok(Decimal::new(scaled_value, new_precision, new_scale))
128121
}
129122

130123
/// Get a decimal with given precision and scale from an i64
131-
pub fn from_i64(value: i64, precision: Precision, scale: i8) -> ConversionResult<Self> {
124+
pub fn from_i64(value: i64, precision: Precision, scale: i8) -> DecimalResult<Self> {
132125
const MINIMAL_PRECISION: u8 = 19;
133126
let raw_precision = precision.value();
134127
if raw_precision < MINIMAL_PRECISION {
135-
return Err(DecimalConversionError(RoundingError(
128+
return Err(DecimalError::RoundingError(
136129
"Precision must be at least 19".to_string(),
137-
)));
130+
));
138131
}
139132
if scale < 0 || raw_precision < MINIMAL_PRECISION + scale as u8 {
140-
return Err(DecimalConversionError(RoundingError(
133+
return Err(DecimalError::RoundingError(
141134
"Can not scale down a decimal".to_string(),
142-
)));
135+
));
143136
}
144137
let scaled_value = scale_scalar(S::from(&value), scale)?;
145138
Ok(Decimal::new(scaled_value, precision, scale))
146139
}
147140

148141
/// Get a decimal with given precision and scale from an i128
149-
pub fn from_i128(value: i128, precision: Precision, scale: i8) -> ConversionResult<Self> {
142+
pub fn from_i128(value: i128, precision: Precision, scale: i8) -> DecimalResult<Self> {
150143
const MINIMAL_PRECISION: u8 = 39;
151144
let raw_precision = precision.value();
152145
if raw_precision < MINIMAL_PRECISION {
153-
return Err(DecimalConversionError(RoundingError(
146+
return Err(DecimalError::RoundingError(
154147
"Precision must be at least 19".to_string(),
155-
)));
148+
));
156149
}
157150
if scale < 0 || raw_precision < MINIMAL_PRECISION + scale as u8 {
158-
return Err(DecimalConversionError(RoundingError(
151+
return Err(DecimalError::RoundingError(
159152
"Can not scale down a decimal".to_string(),
160-
)));
153+
));
161154
}
162155
let scaled_value = scale_scalar(S::from(&value), scale)?;
163156
Ok(Decimal::new(scaled_value, precision, scale))
@@ -169,7 +162,7 @@ impl<S: Scalar> Decimal<S> {
169162
/// the decimal to the specified `target_precision` and `target_scale`,
170163
/// and validates that the adjusted decimal does not exceed the specified precision.
171164
/// If the conversion is successful, it returns the `Scalar` representation;
172-
/// otherwise, it returns a `ConversionError` indicating the type of failure
165+
/// otherwise, it returns a `DecimalError` indicating the type of failure
173166
/// (e.g., exceeding precision limits).
174167
///
175168
/// ## Arguments
@@ -178,25 +171,26 @@ impl<S: Scalar> Decimal<S> {
178171
/// * `target_scale` - The scale (number of decimal places) to use in the scalar.
179172
///
180173
/// ## Errors
181-
/// Returns `InvalidPrecision` error if the number of digits in
174+
/// Returns `DecimalError::InvalidPrecision` error if the number of digits in
182175
/// the decimal exceeds the `target_precision` before or after adjusting for
183176
/// `target_scale`, or if the target precision is zero.
184177
pub(crate) fn try_into_to_scalar<S: Scalar>(
185178
d: &IntermediateDecimal,
186179
target_precision: Precision,
187180
target_scale: i8,
188-
) -> Result<S, ConversionError> {
181+
) -> DecimalResult<S> {
189182
d.try_into_bigint_with_precision_and_scale(target_precision.value(), target_scale)?
190183
.try_into()
184+
.map_err(|e: ScalarConversionError| DecimalError::InvalidDecimal(e.to_string()))
191185
}
192186

193187
/// Scale scalar by the given scale factor. Negative scaling is not allowed.
194188
/// Note that we do not check for overflow.
195-
pub(crate) fn scale_scalar<S: Scalar>(s: S, scale: i8) -> ConversionResult<S> {
189+
pub(crate) fn scale_scalar<S: Scalar>(s: S, scale: i8) -> DecimalResult<S> {
196190
if scale < 0 {
197-
return Err(DecimalConversionError(RoundingError(
191+
return Err(DecimalError::RoundingError(
198192
"Scale factor must be non-negative".to_string(),
199-
)));
193+
));
200194
}
201195
let ten = S::from(10);
202196
let mut res = s;

crates/proof-of-sql/src/base/scalar/mod.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ pub use error::ScalarConversionError;
44
mod mont_scalar;
55
#[cfg(test)]
66
mod mont_scalar_test;
7-
use crate::sql::parse::ConversionError;
87
use core::{cmp::Ordering, ops::Sub};
98
pub use mont_scalar::Curve25519Scalar;
109
pub(crate) use mont_scalar::MontScalar;
@@ -68,7 +67,7 @@ pub trait Scalar:
6867
+ std::convert::From<i32>
6968
+ std::convert::From<i16>
7069
+ std::convert::From<bool>
71-
+ TryFrom<BigInt, Error = ConversionError>
70+
+ TryFrom<BigInt, Error = ScalarConversionError>
7271
{
7372
/// The value (p - 1) / 2. This is "mid-point" of the field - the "six" on the clock.
7473
/// It is the largest signed value that can be represented in the field with the natural embedding.

crates/proof-of-sql/src/base/scalar/mont_scalar.rs

+5-11
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
11
use super::{scalar_conversion_to_int, Scalar, ScalarConversionError};
2-
use crate::{
3-
base::{
4-
math::decimal::{DecimalError, MAX_SUPPORTED_PRECISION},
5-
scalar::mont_scalar::DecimalError::InvalidDecimal,
6-
},
7-
sql::parse::{ConversionError, ConversionError::DecimalConversionError},
8-
};
2+
use crate::base::math::decimal::MAX_SUPPORTED_PRECISION;
93
use ark_ff::{BigInteger, Field, Fp, Fp256, MontBackend, MontConfig, PrimeField};
104
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
115
use bytemuck::TransparentWrapper;
@@ -157,8 +151,8 @@ impl<T: MontConfig<4>> MontScalar<T> {
157151
}
158152
}
159153

160-
impl<T: MontConfig<4>> TryFrom<num_bigint::BigInt> for MontScalar<T> {
161-
type Error = ConversionError;
154+
impl<T: MontConfig<4>> TryFrom<BigInt> for MontScalar<T> {
155+
type Error = ScalarConversionError;
162156

163157
fn try_from(value: BigInt) -> Result<Self, Self::Error> {
164158
// Obtain the absolute value to ignore the sign when counting digits
@@ -169,11 +163,11 @@ impl<T: MontConfig<4>> TryFrom<num_bigint::BigInt> for MontScalar<T> {
169163

170164
// Check if the number of digits exceeds the maximum precision allowed
171165
if digits.len() > MAX_SUPPORTED_PRECISION.into() {
172-
return Err(DecimalConversionError(InvalidDecimal(format!(
166+
return Err(ScalarConversionError::Overflow(format!(
173167
"Attempted to parse a number with {} digits, which exceeds the max supported precision of {}",
174168
digits.len(),
175169
MAX_SUPPORTED_PRECISION
176-
))));
170+
)));
177171
}
178172

179173
// Continue with the previous logic

crates/proof-of-sql/src/sql/parse/error.rs

+12-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
use crate::base::{database::ColumnType, math::decimal::DecimalError};
2-
use proof_of_sql_parser::{posql_time::PoSQLTimestampError, Identifier, ResourceId};
2+
use proof_of_sql_parser::{
3+
intermediate_decimal::IntermediateDecimalError, posql_time::PoSQLTimestampError, Identifier,
4+
ResourceId,
5+
};
36
use thiserror::Error;
47

58
/// Errors from converting an intermediate AST into a provable AST.
@@ -79,6 +82,14 @@ impl From<ConversionError> for String {
7982
}
8083
}
8184

85+
impl From<IntermediateDecimalError> for ConversionError {
86+
fn from(err: IntermediateDecimalError) -> ConversionError {
87+
ConversionError::DecimalConversionError(DecimalError::IntermediateDecimalConversionError(
88+
err,
89+
))
90+
}
91+
}
92+
8293
impl ConversionError {
8394
/// Returns a `ConversionError::InvalidExpression` for non-numeric types used in numeric aggregation functions.
8495
pub fn non_numeric_expr_in_agg<S: Into<String>>(dtype: S, func: S) -> Self {

0 commit comments

Comments
 (0)