From 33bc364ebaf9f14518c423b3334c0c0f746c0b98 Mon Sep 17 00:00:00 2001 From: "bingqing.lbq" Date: Mon, 31 Mar 2025 16:15:24 +0800 Subject: [PATCH] fix in arith Committed-by: bingqing.lbq from Dev container --- .../executor/common/dyn_type/src/arith.rs | 316 +++++++----------- 1 file changed, 112 insertions(+), 204 deletions(-) diff --git a/interactive_engine/executor/common/dyn_type/src/arith.rs b/interactive_engine/executor/common/dyn_type/src/arith.rs index 2fb0f6b5b79c..58e6bb251958 100644 --- a/interactive_engine/executor/common/dyn_type/src/arith.rs +++ b/interactive_engine/executor/common/dyn_type/src/arith.rs @@ -55,96 +55,24 @@ impl std::ops::Add for Primitives { (Byte(a), Float(b)) => Float(a as f32 + b), (ULLong(a), ULLong(b)) => ULLong(a + b), (ULLong(a), ULong(b)) => ULLong(a + b as u128), - (ULLong(a), Long(b)) => { - if b < 0 { - ULLong(a - b.abs() as u128) - } else { - ULLong(a + b as u128) - } - } + (ULLong(a), Long(b)) => ULLong(a + b as u128), (ULLong(a), UInteger(b)) => ULLong(a + b as u128), - (ULLong(a), Integer(b)) => { - if b < 0 { - ULLong(a - b.abs() as u128) - } else { - ULLong(a + b as u128) - } - } - (ULLong(a), Byte(b)) => { - if b < 0 { - ULLong(a - b.abs() as u128) - } else { - ULLong(a + b as u128) - } - } + (ULLong(a), Integer(b)) => ULLong(a + b as u128), + (ULLong(a), Byte(b)) => ULLong(a + b as u128), (ULong(a), ULLong(b)) => ULLong(a as u128 + b), - (Long(a), ULLong(b)) => { - if a < 0 { - ULLong(b - a.abs() as u128) - } else { - ULLong(a as u128 + b) - } - } + (Long(a), ULLong(b)) => ULLong(a as u128 + b), (UInteger(a), ULLong(b)) => ULLong(a as u128 + b), - (Integer(a), ULLong(b)) => { - if a < 0 { - ULLong(b - a.abs() as u128) - } else { - ULLong(a as u128 + b) - } - } - (Byte(a), ULLong(b)) => { - if a < 0 { - ULLong(b - a.abs() as u128) - } else { - ULLong(a as u128 + b) - } - } + (Integer(a), ULLong(b)) => ULLong(a as u128 + b), + (Byte(a), ULLong(b)) => ULLong(a as u128 + b), (ULong(a), ULong(b)) => ULong(a + b), - (ULong(a), Long(b)) => { - if b < 0 { - ULong(a - b.abs() as u64) - } else { - ULong(a + b as u64) - } - } + (ULong(a), Long(b)) => ULong(a + b as u64), (ULong(a), UInteger(b)) => ULong(a + b as u64), - (ULong(a), Integer(b)) => { - if b < 0 { - ULong(a - b.abs() as u64) - } else { - ULong(a + b as u64) - } - } - (ULong(a), Byte(b)) => { - if b < 0 { - ULong(a - b.abs() as u64) - } else { - ULong(a + b as u64) - } - } - (Long(a), ULong(b)) => { - if a < 0 { - ULong(b - a.abs() as u64) - } else { - ULong(a as u64 + b) - } - } + (ULong(a), Integer(b)) => ULong(a + b as u64), + (ULong(a), Byte(b)) => ULong(a + b as u64), + (Long(a), ULong(b)) => ULong(a as u64 + b), (UInteger(a), ULong(b)) => ULong(a as u64 + b), - (Integer(a), ULong(b)) => { - if a < 0 { - ULong(b - a.abs() as u64) - } else { - ULong(a as u64 + b) - } - } - (Byte(a), ULong(b)) => { - if a < 0 { - ULong(b - a.abs() as u64) - } else { - ULong(a as u64 + b) - } - } + (Integer(a), ULong(b)) => ULong(a as u64 + b), + (Byte(a), ULong(b)) => ULong(a as u64 + b), (Long(a), Long(b)) => Long(a + b), (Long(a), UInteger(b)) => Long(a + b as i64), (Long(a), Integer(b)) => Long(a + b as i64), @@ -153,34 +81,10 @@ impl std::ops::Add for Primitives { (Integer(a), Long(b)) => Long(a as i64 + b), (Byte(a), Long(b)) => Long(a as i64 + b), (UInteger(a), UInteger(b)) => UInteger(a + b), - (UInteger(a), Integer(b)) => { - if b < 0 { - UInteger(a - b.abs() as u32) - } else { - UInteger(a + b as u32) - } - } - (UInteger(a), Byte(b)) => { - if b < 0 { - UInteger(a - b.abs() as u32) - } else { - UInteger(a + b as u32) - } - } - (Integer(a), UInteger(b)) => { - if a < 0 { - UInteger(b - a.abs() as u32) - } else { - UInteger(a as u32 + b) - } - } - (Byte(a), UInteger(b)) => { - if a < 0 { - UInteger(b - a.abs() as u32) - } else { - UInteger(a as u32 + b) - } - } + (UInteger(a), Integer(b)) => UInteger(a + b as u32), + (UInteger(a), Byte(b)) => UInteger(a + b as u32), + (Integer(a), UInteger(b)) => UInteger(a as u32 + b), + (Byte(a), UInteger(b)) => UInteger(a as u32 + b), (Integer(a), Integer(b)) => Integer(a + b), (Integer(a), Byte(b)) => Integer(a + b as i32), (Byte(a), Integer(b)) => Integer(a as i32 + b), @@ -227,58 +131,20 @@ impl std::ops::Sub for Primitives { (Byte(a), Float(b)) => Float(a as f32 - b), (ULLong(a), ULLong(b)) => ULLong(a - b), (ULLong(a), ULong(b)) => ULLong(a - b as u128), - (ULLong(a), Long(b)) => { - if b < 0 { - ULLong(a + b.abs() as u128) - } else { - // if a < b, a - b will overflow - ULLong(a - b as u128) - } - } + (ULLong(a), Long(b)) => ULLong(a - b as u128), (ULLong(a), UInteger(b)) => ULLong(a - b as u128), - (ULLong(a), Integer(b)) => { - if b < 0 { - ULLong(a + b.abs() as u128) - } else { - ULLong(a - b as u128) - } - } - (ULLong(a), Byte(b)) => { - if b < 0 { - ULLong(a + b.abs() as u128) - } else { - ULLong(a - b as u128) - } - } + (ULLong(a), Integer(b)) => ULLong(a - b as u128), + (ULLong(a), Byte(b)) => ULLong(a - b as u128), (ULong(a), ULLong(b)) => ULLong(a as u128 - b), - // could be an unexpected result if a < 0, so as the follows when we do the subtraction between a signed negative number and an unsigned number (Long(a), ULLong(b)) => ULLong(a as u128 - b), (UInteger(a), ULLong(b)) => ULLong(a as u128 - b), (Integer(a), ULLong(b)) => ULLong(a as u128 - b), (Byte(a), ULLong(b)) => ULLong(a as u128 - b), (ULong(a), ULong(b)) => ULong(a - b), - (ULong(a), Long(b)) => { - if b < 0 { - ULong(a + b.abs() as u64) - } else { - ULong(a - b as u64) - } - } + (ULong(a), Long(b)) => ULong(a - b as u64), (ULong(a), UInteger(b)) => ULong(a - b as u64), - (ULong(a), Integer(b)) => { - if b < 0 { - ULong(a + b.abs() as u64) - } else { - ULong(a - b as u64) - } - } - (ULong(a), Byte(b)) => { - if b < 0 { - ULong(a + b.abs() as u64) - } else { - ULong(a - b as u64) - } - } + (ULong(a), Integer(b)) => ULong(a - b as u64), + (ULong(a), Byte(b)) => ULong(a - b as u64), (Long(a), ULong(b)) => ULong(a as u64 - b), (UInteger(a), ULong(b)) => ULong(a as u64 - b), (Integer(a), ULong(b)) => ULong(a as u64 - b), @@ -291,20 +157,8 @@ impl std::ops::Sub for Primitives { (Integer(a), Long(b)) => Long(a as i64 - b), (Byte(a), Long(b)) => Long(a as i64 - b), (UInteger(a), UInteger(b)) => UInteger(a - b), - (UInteger(a), Integer(b)) => { - if b < 0 { - UInteger(a + b.abs() as u32) - } else { - UInteger(a - b as u32) - } - } - (UInteger(a), Byte(b)) => { - if b < 0 { - UInteger(a + b.abs() as u32) - } else { - UInteger(a - b as u32) - } - } + (UInteger(a), Integer(b)) => UInteger(a - b as u32), + (UInteger(a), Byte(b)) => UInteger(a - b as u32), (Integer(a), UInteger(b)) => UInteger(a as u32 - b), (Byte(a), UInteger(b)) => UInteger(a as u32 - b), (Integer(a), Integer(b)) => Integer(a - b), @@ -353,7 +207,6 @@ impl std::ops::Mul for Primitives { (Byte(a), Float(b)) => Float(a as f32 * b), (ULLong(a), ULLong(b)) => ULLong(a * b), (ULLong(a), ULong(b)) => ULLong(a * b as u128), - // could be an unexpected result if b < 0, so as the follows when we do the multiplication between a signed negative number and an unsigned number (ULLong(a), Long(b)) => ULLong(a * b as u128), (ULLong(a), UInteger(b)) => ULLong(a * b as u128), (ULLong(a), Integer(b)) => ULLong(a * b as u128), @@ -399,35 +252,6 @@ impl std::ops::Div for Primitives { fn div(self, other: Primitives) -> Self::Output { use super::Primitives::*; - // when divide by zero, - // if it is a integer division, it should panic - // if it is a float division, it should return f64::INFINITY, f64::NEG_INFINITY, or f64::NAN, following IEEE 754 standard - // currently, we follow the IEEE 754 standard for all division - if other == Byte(0) - || other == Integer(0) - || other == UInteger(0) - || other == Long(0) - || other == ULong(0) - || other == ULLong(0) - || other == Double(0.0) - || other == Float(0.0) - { - if self.is_negative() { - return Double(f64::NEG_INFINITY); - } else if self == Byte(0) - || self == Integer(0) - || self == UInteger(0) - || self == Long(0) - || self == ULong(0) - || self == ULLong(0) - || self == Double(0.0) - || self == Float(0.0) - { - return Double(f64::NAN); - } else { - return Double(f64::INFINITY); - } - } match (self, other) { (Double(a), Double(b)) => Double(a / b), (Double(a), Float(b)) => Double(a / b as f64), @@ -459,7 +283,6 @@ impl std::ops::Div for Primitives { (Byte(a), Float(b)) => Float(a as f32 / b), (ULLong(a), ULLong(b)) => ULLong(a / b), (ULLong(a), ULong(b)) => ULLong(a / b as u128), - // could be an unexpected result if b < 0, so as the follows when we do the division between a signed negative number and an unsigned number (ULLong(a), Long(b)) => ULLong(a / b as u128), (ULLong(a), UInteger(b)) => ULLong(a / b as u128), (ULLong(a), Integer(b)) => ULLong(a / b as u128), @@ -770,10 +593,43 @@ mod tests { let b: u32 = a as u32; assert_eq!(b, 4294967295); - // this will overflow and cannot pass the compilation - // let a: u128 = 1; - // let b: i64 = 2; + let a: u128 = 1; + let b: i64 = 2; + // attempt to compute `1_u128 - 2_u128`, which would overflow // let c: u128 = a - b as u128; + let c: u128 = a.wrapping_sub(b as u128); + assert_eq!(c, u128::MAX); + + // 1u32 + (-2i32) + let a: u32 = 1u32 + (-2i32) as u32; + assert_eq!(a, u32::MAX); + let b: u32 = 1u32.wrapping_add(-2i32 as u32); + assert_eq!(b, u32::MAX); + + let a: u32 = ((1u32) as i64 + (-2i32) as i64) as u32; + assert_eq!(a, u32::MAX); + let b: u32 = ((1u32) as i64).wrapping_add((-2i32) as i64) as u32; + assert_eq!(b, u32::MAX); + + // 2u32 + (-1i32) + // attempt to compute `2_u32 + u32::MAX`, which would overflow + // let c: u32 = 2u32 + (-1i32) as u32; + let c: u32 = 2u32.wrapping_add(-1i32 as u32); + assert_eq!(c, 1); + + let c: u32 = ((2u32) as i64 + (-1i32) as i64) as u32; + assert_eq!(c, 1); + let d: u32 = ((2u32) as i64).wrapping_add((-1i32) as i64) as u32; + assert_eq!(d, 1); + + let e: u32 = 4294967295; + let f: i32 = e as i32; + assert_eq!(f, -1); + + let g: u32 = 0; + let h: u32 = 1; + let i: u32 = g.wrapping_sub(h); + assert_eq!(i, u32::MAX); } #[test] @@ -1074,4 +930,56 @@ mod tests { panic!("Expected Double result"); } } + + #[test] + fn test_divide_zero() { + let x = Primitives::Integer(1); + let y = Primitives::Integer(0); + let res = panic::catch_unwind(|| x / y); + assert!(res.is_err()); + + let x = Primitives::UInteger(1); + let y = Primitives::UInteger(0); + let res = panic::catch_unwind(|| x / y); + assert!(res.is_err()); + + let x = Primitives::Float(1.0); + let neg_x = Primitives::Float(-1.0); + let zero_x = Primitives::Float(0.0); + let int_x = Primitives::Integer(1); + let y = Primitives::Float(0.0); + let res = x / y; + if let Primitives::Float(result) = res { + assert!(result.is_infinite()); + } else { + panic!("Expected Float result"); + } + let res = neg_x / y; + if let Primitives::Float(result) = res { + assert!(result.is_infinite()); + } else { + panic!("Expected Float result"); + } + let res = zero_x / y; + if let Primitives::Float(result) = res { + assert!(result.is_nan()); + } else { + panic!("Expected Float result"); + } + let res = int_x / y; + if let Primitives::Float(result) = res { + assert!(result.is_infinite()); + } else { + panic!("Expected Float result"); + } + + let x = Primitives::Double(1.0); + let y = Primitives::Double(0.0); + let res = x / y; + if let Primitives::Double(result) = res { + assert!(result.is_infinite()); + } else { + panic!("Expected Double result"); + } + } }