diff --git a/sdk/log/crate/src/lib.rs b/sdk/log/crate/src/lib.rs index 70dd84dc..2dfbebd4 100644 --- a/sdk/log/crate/src/lib.rs +++ b/sdk/log/crate/src/lib.rs @@ -199,9 +199,107 @@ mod tests { logger.clear(); - // This should have no effect. + // This should have no effect since it is a string. logger.append_with_args("0123456789", &[Argument::Precision(2)]); assert!(&*logger == "0123456789".as_bytes()); + + logger.clear(); + + logger.append_with_args(2u8, &[Argument::Precision(8)]); + assert!(&*logger == "0.00000002".as_bytes()); + + logger.clear(); + + logger.append_with_args(2u8, &[Argument::Precision(u8::MAX)]); + assert!(&*logger == "0.0000000@".as_bytes()); + + let mut logger = Logger::<20>::default(); + + logger.append_with_args(2u8, &[Argument::Precision(u8::MAX)]); + assert!(&*logger == "0.00000000000000000@".as_bytes()); + + logger.clear(); + + logger.append_with_args(20_000u16, &[Argument::Precision(10)]); + assert!(&*logger == "0.0000020000".as_bytes()); + + let mut logger = Logger::<3>::default(); + + logger.append_with_args(2u64, &[Argument::Precision(u8::MAX)]); + assert!(&*logger == "0.@".as_bytes()); + + logger.clear(); + + logger.append_with_args(2u64, &[Argument::Precision(1)]); + assert!(&*logger == "0.2".as_bytes()); + + logger.clear(); + + logger.append_with_args(-2i64, &[Argument::Precision(1)]); + assert!(&*logger == "-0@".as_bytes()); + + let mut logger = Logger::<1>::default(); + + logger.append_with_args(-2i64, &[Argument::Precision(1)]); + assert!(&*logger == "@".as_bytes()); + + let mut logger = Logger::<2>::default(); + + logger.append_with_args(-2i64, &[Argument::Precision(1)]); + assert!(&*logger == "-@".as_bytes()); + + let mut logger = Logger::<20>::default(); + + logger.append_with_args(u64::MAX, &[Argument::Precision(u8::MAX)]); + assert!(&*logger == "0.00000000000000000@".as_bytes()); + + // 255 precision + leading 0 + decimal point + let mut logger = Logger::<257>::default(); + logger.append_with_args(u64::MAX, &[Argument::Precision(u8::MAX)]); + assert!(logger.starts_with("0.00000000000000".as_bytes())); + assert!(logger.ends_with("18446744073709551615".as_bytes())); + + logger.clear(); + + logger.append_with_args(u32::MAX, &[Argument::Precision(u8::MAX)]); + assert!(logger.starts_with("0.00000000000000".as_bytes())); + assert!(logger.ends_with("4294967295".as_bytes())); + + logger.clear(); + + logger.append_with_args(u16::MAX, &[Argument::Precision(u8::MAX)]); + assert!(logger.starts_with("0.00000000000000".as_bytes())); + assert!(logger.ends_with("65535".as_bytes())); + + logger.clear(); + + logger.append_with_args(u8::MAX, &[Argument::Precision(u8::MAX)]); + assert!(logger.starts_with("0.00000000000000".as_bytes())); + assert!(logger.ends_with("255".as_bytes())); + + // 255 precision + sign + leading 0 + decimal point + let mut logger = Logger::<258>::default(); + logger.append_with_args(i64::MIN, &[Argument::Precision(u8::MAX)]); + assert!(logger.starts_with("-0.00000000000000".as_bytes())); + assert!(logger.ends_with("9223372036854775808".as_bytes())); + + logger.clear(); + + logger.append_with_args(i32::MIN, &[Argument::Precision(u8::MAX)]); + assert!(logger.starts_with("-0.00000000000000".as_bytes())); + assert!(logger.ends_with("2147483648".as_bytes())); + + logger.clear(); + + logger.append_with_args(i16::MIN, &[Argument::Precision(u8::MAX)]); + assert!(logger.starts_with("-0.00000000000000".as_bytes())); + assert!(logger.ends_with("32768".as_bytes())); + + logger.clear(); + + logger.append_with_args(i8::MIN, &[Argument::Precision(u8::MAX)]); + assert!(logger.starts_with("-0.00000000000000".as_bytes())); + assert!(logger.ends_with("128".as_bytes())); } #[test] @@ -235,6 +333,46 @@ mod tests { logger.append_with_args("0123456789", &[Argument::TruncateStart(9)]); assert!(&*logger == "..@".as_bytes()); + + let mut logger = Logger::<1>::default(); + + logger.append_with_args("test", &[Argument::TruncateStart(0)]); + assert!(&*logger == "".as_bytes()); + + logger.clear(); + + logger.append_with_args("test", &[Argument::TruncateStart(1)]); + assert!(&*logger == "@".as_bytes()); + + let mut logger = Logger::<2>::default(); + + logger.append_with_args("test", &[Argument::TruncateStart(2)]); + assert!(&*logger == ".@".as_bytes()); + + let mut logger = Logger::<3>::default(); + + logger.append_with_args("test", &[Argument::TruncateStart(3)]); + assert!(&*logger == "..@".as_bytes()); + + let mut logger = Logger::<1>::default(); + + logger.append_with_args("test", &[Argument::TruncateEnd(0)]); + assert!(&*logger == "".as_bytes()); + + logger.clear(); + + logger.append_with_args("test", &[Argument::TruncateEnd(1)]); + assert!(&*logger == "@".as_bytes()); + + let mut logger = Logger::<2>::default(); + + logger.append_with_args("test", &[Argument::TruncateEnd(2)]); + assert!(&*logger == ".@".as_bytes()); + + let mut logger = Logger::<3>::default(); + + logger.append_with_args("test", &[Argument::TruncateEnd(3)]); + assert!(&*logger == "..@".as_bytes()); } #[test] diff --git a/sdk/log/crate/src/logger.rs b/sdk/log/crate/src/logger.rs index 492e1a28..6beaaab6 100644 --- a/sdk/log/crate/src/logger.rs +++ b/sdk/log/crate/src/logger.rs @@ -1,4 +1,6 @@ -use core::{mem::MaybeUninit, ops::Deref, slice::from_raw_parts}; +use core::{ + cmp::min, mem::MaybeUninit, ops::Deref, ptr::copy_nonoverlapping, slice::from_raw_parts, +}; #[cfg(all(target_os = "solana", not(target_feature = "static-syscalls")))] mod syscalls { @@ -8,6 +10,8 @@ mod syscalls { pub fn sol_memcpy_(dst: *mut u8, src: *const u8, n: u64); + pub fn sol_memset_(s: *mut u8, c: u8, n: u64); + pub fn sol_remaining_compute_units() -> u64; } } @@ -26,6 +30,12 @@ mod syscalls { syscall(dest, src, n) } + pub(crate) fn sol_memset_(s: *mut u8, c: u8, n: u64) { + let syscall: extern "C" fn(*mut u8, u8, u64) = + unsafe { core::mem::transmute(930151202u64) }; // murmur32 hash of "sol_memset_" + syscall(s, c, n) + } + pub(crate) fn sol_remaining_compute_units() -> u64 { let syscall: extern "C" fn() -> u64 = unsafe { core::mem::transmute(3991886574u64) }; // murmur32 hash of "sol_remaining_compute_units" syscall() @@ -248,7 +258,7 @@ macro_rules! impl_log_for_unsigned_integer { value /= 10; offset -= 1; // SAFETY: the offset is always within the bounds of the array since - // the `offset` is initialized with the maximum number of digits that + // `offset` is initialized with the maximum number of digits that // the type can have and decremented on each iteration; `remainder` // is always less than 10. unsafe { @@ -267,99 +277,132 @@ macro_rules! impl_log_for_unsigned_integer { 0 }; - // Number of digits written. - let mut written = MAX_DIGITS - offset; - - if precision > 0 { - while precision >= written { - written += 1; - offset -= 1; - // SAFETY: the offset is always within the bounds of the array since - // the `offset` is initialized with the maximum number of digits that - // the type can have and decremented on each iteration. - unsafe { - digits.get_unchecked_mut(offset).write(b'0'); - } - } - // Space for the decimal point. - written += 1; - } - - // Size of the buffer. + let written = MAX_DIGITS - offset; let length = buffer.len(); - // Determines if the value was truncated or not by calculating the - // number of digits that can be written. - let (overflow, written, fraction) = if written <= length { - (false, written, precision) - } else { - (true, length, precision.saturating_sub(written - length)) + + // Space required with the specified precision. We might need + // to add leading zeros and a decimal point, but this is only + // if the precision is greater than zero. + let required = match precision { + 0 => written, + // decimal point + _precision if precision < written => written + 1, + // decimal point + one leading zero + _ => precision + 2, }; + // Determines whether the value will be truncated or not. + let is_truncated = required > length; + // Cap the number of digits to write to the buffer length. + let digits_to_write = min(MAX_DIGITS - offset, length); + // SAFETY: the length of both `digits` and `buffer` arrays are guaranteed - // to be within bounds and the `written` value is always less than their - // maximum length. + // to be within bounds and the `digits_to_write` value is capped to the + // length of the `buffer`. unsafe { let source = digits.as_ptr().add(offset); let ptr = buffer.as_mut_ptr(); - #[cfg(target_os = "solana")] - { - if precision == 0 { - syscalls::sol_memcpy_( - ptr as *mut _, - source as *const _, - written as u64, - ); - } else { - // Integer part of the number. - let integer_part = written - (fraction + 1); - syscalls::sol_memcpy_( - ptr as *mut _, - source as *const _, - integer_part as u64, + // Copy the number to the buffer if no precision is specified. + if precision == 0 { + #[cfg(target_os = "solana")] + syscalls::sol_memcpy_( + ptr as *mut _, + source as *const _, + digits_to_write as u64, + ); + #[cfg(not(target_os = "solana"))] + copy_nonoverlapping(source, ptr, digits_to_write); + } + // If padding is needed to satisfy the precision, add leading zeros + // and a decimal point. + else if precision >= digits_to_write { + // Prefix. + (ptr as *mut u8).write(b'0'); + + if length > 2 { + (ptr.add(1) as *mut u8).write(b'.'); + let padding = min(length - 2, precision - digits_to_write); + + // Precision padding. + #[cfg(target_os = "solana")] + syscalls::sol_memset_( + ptr.add(2) as *mut _, + b'0', + padding as u64, ); - - // Decimal point. - (ptr.add(integer_part) as *mut u8).write(b'.'); + #[cfg(not(target_os = "solana"))] + (ptr.add(2) as *mut u8).write_bytes(b'0', padding); + + let current = 2 + padding; + + // If there is still space, copy (part of) the number. + if current < length { + let remaining = min(digits_to_write, length - current); + + // Number part. + #[cfg(target_os = "solana")] + syscalls::sol_memcpy_( + ptr.add(current) as *mut _, + source as *const _, + remaining as u64, + ); + #[cfg(not(target_os = "solana"))] + copy_nonoverlapping(source, ptr.add(current), remaining); + } + } + } + // No padding is needed, calculate the integer and fractional + // parts and add a decimal point. + else { + let integer_part = digits_to_write - precision; + + // Integer part of the number. + #[cfg(target_os = "solana")] + syscalls::sol_memcpy_( + ptr as *mut _, + source as *const _, + integer_part as u64, + ); + #[cfg(not(target_os = "solana"))] + copy_nonoverlapping(source, ptr, integer_part); + + // Decimal point. + (ptr.add(integer_part) as *mut u8).write(b'.'); + let current = integer_part + 1; + + // If there is still space, copy (part of) the remaining. + if current < length { + let remaining = min(precision, length - current); // Fractional part of the number. + #[cfg(target_os = "solana")] syscalls::sol_memcpy_( - ptr.add(integer_part + 1) as *mut _, + ptr.add(current) as *mut _, source.add(integer_part) as *const _, - fraction as u64, + remaining as u64, ); - } - } - - #[cfg(not(target_os = "solana"))] - { - if precision == 0 { - core::ptr::copy_nonoverlapping(source, ptr, written); - } else { - // Integer part of the number. - let integer_part = written - (fraction + 1); - core::ptr::copy_nonoverlapping(source, ptr, integer_part); - - // Decimal point. - (ptr.add(integer_part) as *mut u8).write(b'.'); - - // Fractional part of the number. - core::ptr::copy_nonoverlapping( + #[cfg(not(target_os = "solana"))] + copy_nonoverlapping( source.add(integer_part), - ptr.add(integer_part + 1), - fraction, + ptr.add(current), + remaining, ); } } } - // There might not have been space for all the value. - if overflow { - // SAFETY: the buffer is checked to be within `written` bounds. + let written = min(required, length); + + // There might not have been space. + if is_truncated { + // SAFETY: `written` is capped to the length of the buffer and + // the required length (`required` is always greater than zero); + // `buffer` is guaranteed to have a length of at least 1. unsafe { - let last = buffer.get_unchecked_mut(written - 1); - last.write(TRUNCATED); + buffer.get_unchecked_mut(written - 1).write(TRUNCATED); } } + written } } @@ -399,6 +442,15 @@ macro_rules! impl_log_for_signed { let mut prefix = 0; if *self < 0 { + if buffer.len() == 1 { + // SAFETY: the buffer is checked to be non-empty. + unsafe { + buffer.get_unchecked_mut(0).write(TRUNCATED); + } + // There is no space for the number, so just return. + return 1; + } + // SAFETY: the buffer is checked to be non-empty. unsafe { buffer.get_unchecked_mut(0).write(b'-'); @@ -511,7 +563,7 @@ unsafe impl Log for &str { // No truncate arguments were provided, so the entire `str` is copied to the buffer // if it fits; otherwise indicates that the `str` was truncated. if truncate_end.is_none() { - let length = core::cmp::min(size, self.len()); + let length = min(size, self.len()); ( buffer.as_mut_ptr(), self.as_ptr(), @@ -520,7 +572,7 @@ unsafe impl Log for &str { length != self.len(), ) } else { - let max_length = core::cmp::min(size, buffer.len()); + let max_length = min(size, buffer.len()); let ptr = buffer.as_mut_ptr(); // The buffer is large enough to hold the entire `str`, so no need to use the @@ -546,7 +598,7 @@ unsafe impl Log for &str { ) }; // Copy the truncated slice to the buffer. - core::ptr::copy_nonoverlapping( + copy_nonoverlapping( TRUNCATED_SLICE.as_ptr(), ptr.add(offset) as *mut _, TRUNCATED_SLICE.len(), @@ -562,17 +614,26 @@ unsafe impl Log for &str { } }; - // SAFETY: the `destination` is always within `length_to_write` bounds. - unsafe { - core::ptr::copy_nonoverlapping(source, destination as *mut _, length_to_write); - } - - // There might not have been space for all the value. - if truncated { + if length_to_write > 0 { // SAFETY: the `destination` is always within `length_to_write` bounds. unsafe { - let last = buffer.get_unchecked_mut(length_to_write - 1); - last.write(TRUNCATED); + #[cfg(target_os = "solana")] + syscalls::sol_memcpy_( + destination as *mut _, + source as *const _, + length_to_write as u64, + ); + #[cfg(not(target_os = "solana"))] + copy_nonoverlapping(source, destination as *mut _, length_to_write); + } + + // There might not have been space for all the value. + if truncated { + // SAFETY: the `destination` is always within `length_to_write` bounds. + unsafe { + let last = buffer.get_unchecked_mut(length_to_write - 1); + last.write(TRUNCATED); + } } }