diff --git a/sdk/pinocchio/src/account_info.rs b/sdk/pinocchio/src/account_info.rs index 2a3c6748..91e70b76 100644 --- a/sdk/pinocchio/src/account_info.rs +++ b/sdk/pinocchio/src/account_info.rs @@ -664,7 +664,6 @@ impl<'a, T: ?Sized> Ref<'a, T> { { // Avoid decrementing the borrow flag on Drop. let orig = ManuallyDrop::new(orig); - Ref { value: NonNull::from(f(&*orig)), state: orig.state, @@ -673,6 +672,27 @@ impl<'a, T: ?Sized> Ref<'a, T> { } } + /// Tries to makes a new `Ref` for a component of the borrowed data. + /// On failure, the original guard is returned alongside with the error + /// returned by the closure. + #[inline] + pub fn try_map( + orig: Ref<'a, T>, + f: impl FnOnce(&T) -> Result<&U, E>, + ) -> Result, (Self, E)> { + // Avoid decrementing the borrow flag on Drop. + let orig = ManuallyDrop::new(orig); + match f(&*orig) { + Ok(value) => Ok(Ref { + value: NonNull::from(value), + state: orig.state, + borrow_shift: orig.borrow_shift, + marker: PhantomData, + }), + Err(e) => Err((ManuallyDrop::into_inner(orig), e)), + } + } + /// Filters and maps a reference to a new type. #[inline] pub fn filter_map(orig: Ref<'a, T>, f: F) -> Result, Self> @@ -736,7 +756,6 @@ impl<'a, T: ?Sized> RefMut<'a, T> { { // Avoid decrementing the borrow flag on Drop. let mut orig = ManuallyDrop::new(orig); - RefMut { value: NonNull::from(f(&mut *orig)), state: orig.state, @@ -745,6 +764,27 @@ impl<'a, T: ?Sized> RefMut<'a, T> { } } + /// Tries to makes a new `RefMut` for a component of the borrowed data. + /// On failure, the original guard is returned alongside with the error + /// returned by the closure. + #[inline] + pub fn try_map( + orig: RefMut<'a, T>, + f: impl FnOnce(&mut T) -> Result<&mut U, E>, + ) -> Result, (Self, E)> { + // Avoid decrementing the borrow flag on Drop. + let mut orig = ManuallyDrop::new(orig); + match f(&mut *orig) { + Ok(value) => Ok(RefMut { + value: NonNull::from(value), + state: orig.state, + borrow_bitmask: orig.borrow_bitmask, + marker: PhantomData, + }), + Err(e) => Err((ManuallyDrop::into_inner(orig), e)), + } + } + /// Filters and maps a mutable reference to a new type. #[inline] pub fn filter_map(orig: RefMut<'a, T>, f: F) -> Result, Self> @@ -753,17 +793,13 @@ impl<'a, T: ?Sized> RefMut<'a, T> { { // Avoid decrementing the mutable borrow flag on Drop. let mut orig = ManuallyDrop::new(orig); - match f(&mut *orig) { - Some(value) => { - let value = NonNull::from(value); - Ok(RefMut { - value, - state: orig.state, - borrow_bitmask: orig.borrow_bitmask, - marker: PhantomData, - }) - } + Some(value) => Ok(RefMut { + value: NonNull::from(value), + state: orig.state, + borrow_bitmask: orig.borrow_bitmask, + marker: PhantomData, + }), None => Err(ManuallyDrop::into_inner(orig)), } } @@ -821,6 +857,19 @@ mod tests { assert_eq!(state, NOT_BORROWED - (1 << DATA_BORROW_SHIFT)); assert_eq!(*new_ref, 3); + let Ok(new_ref) = Ref::try_map::<_, u8>(new_ref, |_| Ok(&4)) else { + unreachable!() + }; + + assert_eq!(state, NOT_BORROWED - (1 << DATA_BORROW_SHIFT)); + assert_eq!(*new_ref, 4); + + let (new_ref, err) = Ref::try_map::(new_ref, |_| Err(5)).unwrap_err(); + assert_eq!(state, NOT_BORROWED - (1 << DATA_BORROW_SHIFT)); + assert_eq!(err, 5); + // Unchanged + assert_eq!(*new_ref, 4); + let new_ref = Ref::filter_map(new_ref, |_| Option::<&u8>::None); assert_eq!(state, NOT_BORROWED - (1 << DATA_BORROW_SHIFT));