Skip to content

fix(interactive): fix some arith implementation in Insight #4594

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
316 changes: 112 additions & 204 deletions interactive_engine/executor/common/dyn_type/src/arith.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,96 +55,24 @@ impl std::ops::Add for Primitives {
(Byte(a), Float(b)) => Float(a as f32 + b),
(ULLong(a), ULLong(b)) => ULLong(a + b),
(ULLong(a), ULong(b)) => ULLong(a + b as u128),
(ULLong(a), Long(b)) => {
if b < 0 {
ULLong(a - b.abs() as u128)
} else {
ULLong(a + b as u128)
}
}
(ULLong(a), Long(b)) => ULLong(a + b as u128),
(ULLong(a), UInteger(b)) => ULLong(a + b as u128),
(ULLong(a), Integer(b)) => {
if b < 0 {
ULLong(a - b.abs() as u128)
} else {
ULLong(a + b as u128)
}
}
(ULLong(a), Byte(b)) => {
if b < 0 {
ULLong(a - b.abs() as u128)
} else {
ULLong(a + b as u128)
}
}
(ULLong(a), Integer(b)) => ULLong(a + b as u128),
(ULLong(a), Byte(b)) => ULLong(a + b as u128),
(ULong(a), ULLong(b)) => ULLong(a as u128 + b),
(Long(a), ULLong(b)) => {
if a < 0 {
ULLong(b - a.abs() as u128)
} else {
ULLong(a as u128 + b)
}
}
(Long(a), ULLong(b)) => ULLong(a as u128 + b),
(UInteger(a), ULLong(b)) => ULLong(a as u128 + b),
(Integer(a), ULLong(b)) => {
if a < 0 {
ULLong(b - a.abs() as u128)
} else {
ULLong(a as u128 + b)
}
}
(Byte(a), ULLong(b)) => {
if a < 0 {
ULLong(b - a.abs() as u128)
} else {
ULLong(a as u128 + b)
}
}
(Integer(a), ULLong(b)) => ULLong(a as u128 + b),
(Byte(a), ULLong(b)) => ULLong(a as u128 + b),
(ULong(a), ULong(b)) => ULong(a + b),
(ULong(a), Long(b)) => {
if b < 0 {
ULong(a - b.abs() as u64)
} else {
ULong(a + b as u64)
}
}
(ULong(a), Long(b)) => ULong(a + b as u64),
(ULong(a), UInteger(b)) => ULong(a + b as u64),
(ULong(a), Integer(b)) => {
if b < 0 {
ULong(a - b.abs() as u64)
} else {
ULong(a + b as u64)
}
}
(ULong(a), Byte(b)) => {
if b < 0 {
ULong(a - b.abs() as u64)
} else {
ULong(a + b as u64)
}
}
(Long(a), ULong(b)) => {
if a < 0 {
ULong(b - a.abs() as u64)
} else {
ULong(a as u64 + b)
}
}
(ULong(a), Integer(b)) => ULong(a + b as u64),
(ULong(a), Byte(b)) => ULong(a + b as u64),
(Long(a), ULong(b)) => ULong(a as u64 + b),
(UInteger(a), ULong(b)) => ULong(a as u64 + b),
(Integer(a), ULong(b)) => {
if a < 0 {
ULong(b - a.abs() as u64)
} else {
ULong(a as u64 + b)
}
}
(Byte(a), ULong(b)) => {
if a < 0 {
ULong(b - a.abs() as u64)
} else {
ULong(a as u64 + b)
}
}
(Integer(a), ULong(b)) => ULong(a as u64 + b),
(Byte(a), ULong(b)) => ULong(a as u64 + b),
(Long(a), Long(b)) => Long(a + b),
(Long(a), UInteger(b)) => Long(a + b as i64),
(Long(a), Integer(b)) => Long(a + b as i64),
Expand All @@ -153,34 +81,10 @@ impl std::ops::Add for Primitives {
(Integer(a), Long(b)) => Long(a as i64 + b),
(Byte(a), Long(b)) => Long(a as i64 + b),
(UInteger(a), UInteger(b)) => UInteger(a + b),
(UInteger(a), Integer(b)) => {
if b < 0 {
UInteger(a - b.abs() as u32)
} else {
UInteger(a + b as u32)
}
}
(UInteger(a), Byte(b)) => {
if b < 0 {
UInteger(a - b.abs() as u32)
} else {
UInteger(a + b as u32)
}
}
(Integer(a), UInteger(b)) => {
if a < 0 {
UInteger(b - a.abs() as u32)
} else {
UInteger(a as u32 + b)
}
}
(Byte(a), UInteger(b)) => {
if a < 0 {
UInteger(b - a.abs() as u32)
} else {
UInteger(a as u32 + b)
}
}
(UInteger(a), Integer(b)) => UInteger(a + b as u32),
(UInteger(a), Byte(b)) => UInteger(a + b as u32),
(Integer(a), UInteger(b)) => UInteger(a as u32 + b),
(Byte(a), UInteger(b)) => UInteger(a as u32 + b),
(Integer(a), Integer(b)) => Integer(a + b),
(Integer(a), Byte(b)) => Integer(a + b as i32),
(Byte(a), Integer(b)) => Integer(a as i32 + b),
Expand Down Expand Up @@ -227,58 +131,20 @@ impl std::ops::Sub for Primitives {
(Byte(a), Float(b)) => Float(a as f32 - b),
(ULLong(a), ULLong(b)) => ULLong(a - b),
(ULLong(a), ULong(b)) => ULLong(a - b as u128),
(ULLong(a), Long(b)) => {
if b < 0 {
ULLong(a + b.abs() as u128)
} else {
// if a < b, a - b will overflow
ULLong(a - b as u128)
}
}
(ULLong(a), Long(b)) => ULLong(a - b as u128),
(ULLong(a), UInteger(b)) => ULLong(a - b as u128),
(ULLong(a), Integer(b)) => {
if b < 0 {
ULLong(a + b.abs() as u128)
} else {
ULLong(a - b as u128)
}
}
(ULLong(a), Byte(b)) => {
if b < 0 {
ULLong(a + b.abs() as u128)
} else {
ULLong(a - b as u128)
}
}
(ULLong(a), Integer(b)) => ULLong(a - b as u128),
(ULLong(a), Byte(b)) => ULLong(a - b as u128),
(ULong(a), ULLong(b)) => ULLong(a as u128 - b),
// could be an unexpected result if a < 0, so as the follows when we do the subtraction between a signed negative number and an unsigned number
(Long(a), ULLong(b)) => ULLong(a as u128 - b),
(UInteger(a), ULLong(b)) => ULLong(a as u128 - b),
(Integer(a), ULLong(b)) => ULLong(a as u128 - b),
(Byte(a), ULLong(b)) => ULLong(a as u128 - b),
(ULong(a), ULong(b)) => ULong(a - b),
(ULong(a), Long(b)) => {
if b < 0 {
ULong(a + b.abs() as u64)
} else {
ULong(a - b as u64)
}
}
(ULong(a), Long(b)) => ULong(a - b as u64),
(ULong(a), UInteger(b)) => ULong(a - b as u64),
(ULong(a), Integer(b)) => {
if b < 0 {
ULong(a + b.abs() as u64)
} else {
ULong(a - b as u64)
}
}
(ULong(a), Byte(b)) => {
if b < 0 {
ULong(a + b.abs() as u64)
} else {
ULong(a - b as u64)
}
}
(ULong(a), Integer(b)) => ULong(a - b as u64),
(ULong(a), Byte(b)) => ULong(a - b as u64),
(Long(a), ULong(b)) => ULong(a as u64 - b),
(UInteger(a), ULong(b)) => ULong(a as u64 - b),
(Integer(a), ULong(b)) => ULong(a as u64 - b),
Expand All @@ -291,20 +157,8 @@ impl std::ops::Sub for Primitives {
(Integer(a), Long(b)) => Long(a as i64 - b),
(Byte(a), Long(b)) => Long(a as i64 - b),
(UInteger(a), UInteger(b)) => UInteger(a - b),
(UInteger(a), Integer(b)) => {
if b < 0 {
UInteger(a + b.abs() as u32)
} else {
UInteger(a - b as u32)
}
}
(UInteger(a), Byte(b)) => {
if b < 0 {
UInteger(a + b.abs() as u32)
} else {
UInteger(a - b as u32)
}
}
(UInteger(a), Integer(b)) => UInteger(a - b as u32),
(UInteger(a), Byte(b)) => UInteger(a - b as u32),
(Integer(a), UInteger(b)) => UInteger(a as u32 - b),
(Byte(a), UInteger(b)) => UInteger(a as u32 - b),
(Integer(a), Integer(b)) => Integer(a - b),
Expand Down Expand Up @@ -353,7 +207,6 @@ impl std::ops::Mul for Primitives {
(Byte(a), Float(b)) => Float(a as f32 * b),
(ULLong(a), ULLong(b)) => ULLong(a * b),
(ULLong(a), ULong(b)) => ULLong(a * b as u128),
// could be an unexpected result if b < 0, so as the follows when we do the multiplication between a signed negative number and an unsigned number
(ULLong(a), Long(b)) => ULLong(a * b as u128),
(ULLong(a), UInteger(b)) => ULLong(a * b as u128),
(ULLong(a), Integer(b)) => ULLong(a * b as u128),
Expand Down Expand Up @@ -399,35 +252,6 @@ impl std::ops::Div for Primitives {

fn div(self, other: Primitives) -> Self::Output {
use super::Primitives::*;
// when divide by zero,
// if it is a integer division, it should panic
// if it is a float division, it should return f64::INFINITY, f64::NEG_INFINITY, or f64::NAN, following IEEE 754 standard
// currently, we follow the IEEE 754 standard for all division
if other == Byte(0)
|| other == Integer(0)
|| other == UInteger(0)
|| other == Long(0)
|| other == ULong(0)
|| other == ULLong(0)
|| other == Double(0.0)
|| other == Float(0.0)
{
if self.is_negative() {
return Double(f64::NEG_INFINITY);
} else if self == Byte(0)
|| self == Integer(0)
|| self == UInteger(0)
|| self == Long(0)
|| self == ULong(0)
|| self == ULLong(0)
|| self == Double(0.0)
|| self == Float(0.0)
{
return Double(f64::NAN);
} else {
return Double(f64::INFINITY);
}
}
match (self, other) {
(Double(a), Double(b)) => Double(a / b),
(Double(a), Float(b)) => Double(a / b as f64),
Expand Down Expand Up @@ -459,7 +283,6 @@ impl std::ops::Div for Primitives {
(Byte(a), Float(b)) => Float(a as f32 / b),
(ULLong(a), ULLong(b)) => ULLong(a / b),
(ULLong(a), ULong(b)) => ULLong(a / b as u128),
// could be an unexpected result if b < 0, so as the follows when we do the division between a signed negative number and an unsigned number
(ULLong(a), Long(b)) => ULLong(a / b as u128),
(ULLong(a), UInteger(b)) => ULLong(a / b as u128),
(ULLong(a), Integer(b)) => ULLong(a / b as u128),
Expand Down Expand Up @@ -770,10 +593,43 @@ mod tests {
let b: u32 = a as u32;
assert_eq!(b, 4294967295);

// this will overflow and cannot pass the compilation
// let a: u128 = 1;
// let b: i64 = 2;
let a: u128 = 1;
let b: i64 = 2;
// attempt to compute `1_u128 - 2_u128`, which would overflow
// let c: u128 = a - b as u128;
let c: u128 = a.wrapping_sub(b as u128);
assert_eq!(c, u128::MAX);

// 1u32 + (-2i32)
let a: u32 = 1u32 + (-2i32) as u32;
assert_eq!(a, u32::MAX);
let b: u32 = 1u32.wrapping_add(-2i32 as u32);
assert_eq!(b, u32::MAX);

let a: u32 = ((1u32) as i64 + (-2i32) as i64) as u32;
assert_eq!(a, u32::MAX);
let b: u32 = ((1u32) as i64).wrapping_add((-2i32) as i64) as u32;
assert_eq!(b, u32::MAX);

// 2u32 + (-1i32)
// attempt to compute `2_u32 + u32::MAX`, which would overflow
// let c: u32 = 2u32 + (-1i32) as u32;
let c: u32 = 2u32.wrapping_add(-1i32 as u32);
assert_eq!(c, 1);

let c: u32 = ((2u32) as i64 + (-1i32) as i64) as u32;
assert_eq!(c, 1);
let d: u32 = ((2u32) as i64).wrapping_add((-1i32) as i64) as u32;
assert_eq!(d, 1);

let e: u32 = 4294967295;
let f: i32 = e as i32;
assert_eq!(f, -1);

let g: u32 = 0;
let h: u32 = 1;
let i: u32 = g.wrapping_sub(h);
assert_eq!(i, u32::MAX);
}

#[test]
Expand Down Expand Up @@ -1074,4 +930,56 @@ mod tests {
panic!("Expected Double result");
}
}

#[test]
fn test_divide_zero() {
let x = Primitives::Integer(1);
let y = Primitives::Integer(0);
let res = panic::catch_unwind(|| x / y);
assert!(res.is_err());

let x = Primitives::UInteger(1);
let y = Primitives::UInteger(0);
let res = panic::catch_unwind(|| x / y);
assert!(res.is_err());

let x = Primitives::Float(1.0);
let neg_x = Primitives::Float(-1.0);
let zero_x = Primitives::Float(0.0);
let int_x = Primitives::Integer(1);
let y = Primitives::Float(0.0);
let res = x / y;
if let Primitives::Float(result) = res {
assert!(result.is_infinite());
} else {
panic!("Expected Float result");
}
let res = neg_x / y;
if let Primitives::Float(result) = res {
assert!(result.is_infinite());
} else {
panic!("Expected Float result");
}
let res = zero_x / y;
if let Primitives::Float(result) = res {
assert!(result.is_nan());
} else {
panic!("Expected Float result");
}
let res = int_x / y;
if let Primitives::Float(result) = res {
assert!(result.is_infinite());
} else {
panic!("Expected Float result");
}

let x = Primitives::Double(1.0);
let y = Primitives::Double(0.0);
let res = x / y;
if let Primitives::Double(result) = res {
assert!(result.is_infinite());
} else {
panic!("Expected Double result");
}
}
}
Loading