@@ -1694,39 +1694,41 @@ where
16941694 } )
16951695 . collect ( ) ;
16961696
1697- let mut tlc_fail = None ;
16981697 let mut hash_algorithm = HashAlgorithm :: default ( ) ;
16991698 let mut tlc_preimage_map = HashMap :: new ( ) ;
17001699
17011700 // we generally set all validation error as IncorrectOrUnknownPaymentDetails, this failure
1702- // is expected to be the scenarios malicious sender have sent a inconsistent data with MPP
1703- ' validation: {
1701+ // is expected to be the scenarios malicious sender have sent a inconsistent data with MPP,
1702+ // and we don't want to keep them in the database.
1703+ let tlc_fail = ' validation: {
1704+ macro_rules! validation_fail {
1705+ ( $msg: expr) => { {
1706+ error!( $msg, payment_hash) ;
1707+ break ' validation Some ( TlcErr :: new(
1708+ TlcErrorCode :: IncorrectOrUnknownPaymentDetails ,
1709+ ) ) ;
1710+ } } ;
1711+ }
1712+
17041713 // check if all tlcs have the same total amount
17051714 if tlcs. len ( ) > 1
17061715 && !tlcs
17071716 . windows ( 2 )
17081717 . all ( |w| w[ 0 ] . total_amount == w[ 1 ] . total_amount )
17091718 {
1710- error ! ( "TLCs have inconsistent total_amount: {:?}" , tlcs) ;
1711- break ' validation tlc_fail =
1712- Some ( TlcErr :: new ( TlcErrorCode :: IncorrectOrUnknownPaymentDetails ) ) ;
1719+ validation_fail ! ( "TLCs have inconsistent total_amount: {:?}" ) ;
17131720 }
17141721
17151722 // check if tlc set are fulfilled
17161723 let invoice = self . store . get_invoice ( & payment_hash) ;
17171724 let Some ( invoice) = invoice else {
1718- error ! (
1719- "Try to settle mpp tlc set, but invoice not found for payment hash {:?}" ,
1720- payment_hash
1721- ) ;
1722- break ' validation tlc_fail =
1723- Some ( TlcErr :: new ( TlcErrorCode :: IncorrectOrUnknownPaymentDetails ) ) ;
1725+ validation_fail ! (
1726+ "Try to settle mpp tlc set, but invoice not found for payment hash {:?}"
1727+ )
17241728 } ;
17251729
17261730 let Some ( mpp_mode) = invoice. mpp_mode ( ) else {
1727- error ! ( "try to settle down mpp payment_hash: {:?} while the invoice does no support MPP" , payment_hash) ;
1728- break ' validation tlc_fail =
1729- Some ( TlcErr :: new ( TlcErrorCode :: IncorrectOrUnknownPaymentDetails ) ) ;
1731+ validation_fail ! ( "try to settle down mpp payment_hash: {:?} while the invoice does no support MPP" ) ;
17301732 } ;
17311733
17321734 if !is_invoice_fulfilled ( & invoice, & tlcs) {
@@ -1738,49 +1740,45 @@ where
17381740 match mpp_mode {
17391741 MppMode :: BasicMpp => {
17401742 let Some ( preimage) = self . store . get_preimage ( & payment_hash) else {
1741- error ! (
1742- "basic MPP can not get preimage for payment: {:?}" ,
1743- payment_hash
1743+ validation_fail ! (
1744+ "basic MPP can not get preimage for payment: {:?}"
17441745 ) ;
1745- break ' validation tlc_fail = Some ( TlcErr :: new (
1746- TlcErrorCode :: IncorrectOrUnknownPaymentDetails ,
1747- ) ) ;
17481746 } ;
17491747
17501748 for tlc in tlcs. iter ( ) {
17511749 tlc_preimage_map. insert ( ( tlc. channel_id , tlc. id ( ) ) , preimage) ;
17521750 }
17531751 }
17541752 MppMode :: AtomicMpp => {
1755- let mut atomic_mpp_data =
1753+ let mut tlcs_mpp_data =
17561754 self . store . get_atomic_mpp_payment_data ( & payment_hash) ;
17571755
1758- if atomic_mpp_data. len ( ) != tlcs. len ( ) {
1759- error ! (
1760- "atomic mpp don't have enough mpp data for payment_hash: {:?}" ,
1761- payment_hash
1756+ if tlcs_mpp_data. len ( ) != tlcs. len ( ) {
1757+ validation_fail ! (
1758+ "atomic mpp don't have enough mpp data for payment_hash: {:?}"
17621759 ) ;
1763- break ' validation tlc_fail = Some ( TlcErr :: new (
1764- TlcErrorCode :: IncorrectOrUnknownPaymentDetails ,
1765- ) ) ;
17661760 }
17671761
1768- atomic_mpp_data . sort_by ( |a , b | a. 1 . index ( ) . cmp ( & b. 1 . index ( ) ) ) ;
1762+ tlcs_mpp_data . sort_by ( |( _ , a ) , ( _ , b ) | a. index ( ) . cmp ( & b. index ( ) ) ) ;
17691763 let index: Vec < u16 > =
1770- atomic_mpp_data. iter ( ) . map ( |a| a. 1 . index ( ) ) . collect ( ) ;
1771- let expected_index: Vec < u16 > = ( 0 ..tlcs. len ( ) as u16 ) . collect ( ) ;
1764+ tlcs_mpp_data. iter ( ) . map ( |( _, data) | data. index ( ) ) . collect ( ) ;
17721765
1766+ let total_count = tlcs_mpp_data[ 0 ] . 1 . total_amp_count ;
1767+ if tlcs_mpp_data
1768+ . iter ( )
1769+ . any ( |( _, data) | data. total_amp_count != total_count)
1770+ {
1771+ validation_fail ! ( "atomic mpp total count are not the same for payment_hash: {:?}" ) ;
1772+ }
1773+
1774+ let expected_index: Vec < u16 > = ( 0 ..total_count) . collect ( ) ;
17731775 if index != expected_index {
1774- error ! (
1775- "atomic mpp index are not expected for payment_hash: {:?}" ,
1776- payment_hash
1776+ validation_fail ! (
1777+ "atomic mpp index are not expected for payment_hash: {:?}"
17771778 ) ;
1778- break ' validation tlc_fail = Some ( TlcErr :: new (
1779- TlcErrorCode :: IncorrectOrUnknownPaymentDetails ,
1780- ) ) ;
17811779 }
17821780
1783- let child_descs: Vec < AmpChildDesc > = atomic_mpp_data
1781+ let child_descs: Vec < AmpChildDesc > = tlcs_mpp_data
17841782 . iter ( )
17851783 . map ( |( _, data) | data. child_desc . clone ( ) )
17861784 . collect ( ) ;
@@ -1789,7 +1787,7 @@ where
17891787 debug_assert_eq ! ( child_descs. len( ) , children. len( ) ) ;
17901788
17911789 for ( ( ( channel_id, tlc_id) , _) , child) in
1792- atomic_mpp_data . iter ( ) . zip ( children. iter ( ) )
1790+ tlcs_mpp_data . iter ( ) . zip ( children. iter ( ) )
17931791 {
17941792 tlc_preimage_map. insert ( ( * channel_id, * tlc_id) , child. preimage ) ;
17951793 }
@@ -1803,18 +1801,15 @@ where
18031801 let hash: Hash256 = hash_algorithm. hash ( preimage. as_ref ( ) ) . into ( ) ;
18041802
18051803 if hash != payment_hash {
1806- error ! (
1807- "verify AMP preimage {:?} for payment hash {:?} is not valid" ,
1808- preimage, payment_hash
1804+ validation_fail ! (
1805+ "verify AMP preimage for payment hash {:?} is not valid"
18091806 ) ;
1810- break ' validation tlc_fail = Some ( TlcErr :: new (
1811- TlcErrorCode :: IncorrectOrUnknownPaymentDetails ,
1812- ) ) ;
18131807 }
18141808 }
18151809 }
18161810 }
1817- } // end of 'validation block
1811+ None
1812+ } ; // end of 'validation block
18181813
18191814 // remove tlcs
18201815 for tlc in tlcs {
@@ -2767,16 +2762,17 @@ where
27672762 let hash_algorithm = attempts[ 0 ] . route_hops . first ( ) . unwrap ( ) . hash_algorithm ;
27682763
27692764 let total_count = attempts. len ( ) as u16 ;
2770- let amps : Vec < AmpPaymentData > = secrets
2765+ let child_descs : Vec < _ > = secrets
27712766 . iter ( )
27722767 . enumerate ( )
2773- . map ( |( i, & share) | AmpPaymentData :: new ( payment_hash , i as u16 , total_count , share) )
2768+ . map ( |( i, & share) | AmpChildDesc :: new ( i as u16 , share) )
27742769 . collect ( ) ;
27752770
2776- let child_descs: Vec < _ > = amps. iter ( ) . map ( |x| x. child_desc . clone ( ) ) . collect ( ) ;
27772771 let children = AmpChild :: construct_amp_children ( & child_descs, hash_algorithm) ;
2778- for ( ( attempt, amp_data ) , child) in attempts. iter_mut ( ) . zip ( & amps ) . zip ( & children ) {
2772+ for ( index , ( attempt, child) ) in attempts. iter_mut ( ) . zip ( & children ) . enumerate ( ) {
27792773 let last_hop = attempt. route_hops . last_mut ( ) . expect ( "last hop" ) ;
2774+ let amp_data =
2775+ AmpPaymentData :: new ( payment_hash, total_count, child_descs[ index] . clone ( ) ) ;
27802776 let mut custom_records = last_hop. custom_records . clone ( ) . unwrap_or_default ( ) ;
27812777 amp_data. write ( & mut custom_records) ;
27822778 last_hop. custom_records = Some ( custom_records) ;
0 commit comments