@@ -1663,21 +1663,6 @@ where
16631663 }
16641664 NetworkActorCommand :: SettleMPPTlcSet ( payment_hash) => {
16651665 let hold_tlcs = self . store . get_payment_hold_tlcs ( payment_hash) ;
1666- // check all hold tlcs mpp_mode is same
1667- let mpp_modes: HashSet < _ > =
1668- hold_tlcs. iter ( ) . map ( |hold_tlc| hold_tlc. mpp_mode ) . collect ( ) ;
1669- if mpp_modes. len ( ) > 1 {
1670- error ! (
1671- "payment_hash {:?} have different mpp_mode, cannot settle" ,
1672- payment_hash
1673- ) ;
1674- return Ok ( ( ) ) ;
1675- }
1676- let Some ( mpp_mode) = mpp_modes. iter ( ) . next ( ) . cloned ( ) else {
1677- error ! ( "payment_hash {:?} have not hold tlcs" , payment_hash) ;
1678- return Ok ( ( ) ) ;
1679- } ;
1680-
16811666 // load hold tlcs
16821667 let tlcs: Vec < _ > = hold_tlcs
16831668 . iter ( )
@@ -1690,86 +1675,151 @@ where
16901675
16911676 let mut tlc_fail = None ;
16921677 let mut hash_algorithm = HashAlgorithm :: default ( ) ;
1678+ let mut tlc_preimage_map = HashMap :: new ( ) ;
1679+
1680+ ' validation: {
1681+ // check all hold tlcs mpp_mode is same
1682+ let mpp_modes: HashSet < _ > =
1683+ hold_tlcs. iter ( ) . map ( |hold_tlc| hold_tlc. mpp_mode ) . collect ( ) ;
1684+ if mpp_modes. len ( ) > 1 {
1685+ error ! (
1686+ "payment_hash {:?} have different mpp_mode, cannot settle" ,
1687+ payment_hash
1688+ ) ;
1689+ tlc_fail =
1690+ Some ( TlcErr :: new ( TlcErrorCode :: IncorrectOrUnknownPaymentDetails ) ) ;
1691+ break ' validation;
1692+ }
1693+ let Some ( mpp_mode) = mpp_modes. iter ( ) . next ( ) . cloned ( ) else {
1694+ error ! ( "payment_hash {:?} have not hold tlcs" , payment_hash) ;
1695+ return Ok ( ( ) ) ;
1696+ } ;
1697+
1698+ // check if all tlcs have the same total amount
1699+ if tlcs. len ( ) > 1
1700+ && !tlcs
1701+ . windows ( 2 )
1702+ . all ( |w| w[ 0 ] . total_amount == w[ 1 ] . total_amount )
1703+ {
1704+ error ! ( "TLCs have inconsistent total_amount: {:?}" , tlcs) ;
1705+ tlc_fail =
1706+ Some ( TlcErr :: new ( TlcErrorCode :: IncorrectOrUnknownPaymentDetails ) ) ;
1707+ break ' validation;
1708+ }
16931709
1694- // check if all tlcs have the same total amount
1695- if tlcs. len ( ) > 1
1696- && !tlcs
1697- . windows ( 2 )
1698- . all ( |w| w[ 0 ] . total_amount == w[ 1 ] . total_amount )
1699- {
1700- error ! ( "TLCs have inconsistent total_amount: {:?}" , tlcs) ;
1701- tlc_fail = Some ( TlcErr :: new ( TlcErrorCode :: IncorrectOrUnknownPaymentDetails ) ) ;
1702- } else {
17031710 // check if tlc set are fulfilled
1704- let Some ( invoice) = self . store . get_invoice ( & payment_hash) else {
1711+ let invoice = self . store . get_invoice ( & payment_hash) ;
1712+ let Some ( invoice) = invoice else {
17051713 error ! (
17061714 "Try to settle mpp tlc set, but invoice not found for payment hash {:?}" ,
17071715 payment_hash
17081716 ) ;
1709- return Ok ( ( ) ) ;
1717+ tlc_fail =
1718+ Some ( TlcErr :: new ( TlcErrorCode :: IncorrectOrUnknownPaymentDetails ) ) ;
1719+ break ' validation;
17101720 } ;
1711- // just return if invoice is not fulfilled
1721+
1722+ let is_compatible = match mpp_mode {
1723+ MppMode :: BasicMpp => invoice. basic_mpp ( ) ,
1724+ MppMode :: AtomicMpp => invoice. atomic_mpp ( ) ,
1725+ } ;
1726+
1727+ if !is_compatible {
1728+ tlc_fail =
1729+ Some ( TlcErr :: new ( TlcErrorCode :: IncorrectOrUnknownPaymentDetails ) ) ;
1730+ break ' validation;
1731+ }
1732+
17121733 if !is_invoice_fulfilled ( & invoice, & tlcs) {
17131734 return Ok ( ( ) ) ;
17141735 }
1736+
17151737 hash_algorithm = invoice. hash_algorithm ( ) . cloned ( ) . unwrap_or ( hash_algorithm) ;
1716- }
1717- let mut tlc_preimages = HashMap :: new ( ) ;
1718- match mpp_mode {
1719- MppMode :: BasicMpp => {
1720- let Some ( preimage) = self . store . get_preimage ( & payment_hash) else {
1721- error ! (
1722- "basic MPP can not get preimage for payment: {:?}" ,
1723- payment_hash
1724- ) ;
1725- return Ok ( ( ) ) ;
1726- } ;
1727- for tlc in tlcs. iter ( ) {
1728- tlc_preimages. insert ( ( tlc. channel_id , tlc. id ( ) ) , preimage) ;
1729- }
1730- }
1731- MppMode :: AtomicMpp => {
1732- let mut atomic_mpp_data =
1733- self . store . get_atomic_mpp_payment_data ( & payment_hash) ;
1734- if atomic_mpp_data. len ( ) != tlcs. len ( ) {
1735- error ! (
1736- "atomic mpp don't have enough mpp data for payment_hash: {:?}" ,
1737- payment_hash
1738- ) ;
1738+ // Generate preimages based on MPP mode
1739+ match mpp_mode {
1740+ MppMode :: BasicMpp => {
1741+ let Some ( preimage) = self . store . get_preimage ( & payment_hash) else {
1742+ error ! (
1743+ "basic MPP can not get preimage for payment: {:?}" ,
1744+ payment_hash
1745+ ) ;
1746+ tlc_fail = Some ( TlcErr :: new (
1747+ TlcErrorCode :: IncorrectOrUnknownPaymentDetails ,
1748+ ) ) ;
1749+ break ' validation;
1750+ } ;
1751+
1752+ for tlc in tlcs. iter ( ) {
1753+ tlc_preimage_map. insert ( ( tlc. channel_id , tlc. id ( ) ) , preimage) ;
1754+ }
17391755 }
1740- atomic_mpp_data. sort_by ( |a, b| a. 1 . index . cmp ( & b. 1 . index ) ) ;
1741- let child_descs: Vec < ChildDesc > = atomic_mpp_data
1742- . iter ( )
1743- . map ( |( _, data) | ChildDesc :: new ( data. secret , data. index ) )
1744- . collect ( ) ;
1745- let children = reconstruct_children ( & child_descs, hash_algorithm) ;
1746- debug_assert_eq ! ( child_descs. len( ) , children. len( ) ) ;
1747- for ( ( ( channel_id, tlc_id) , _) , child) in
1748- atomic_mpp_data. iter ( ) . zip ( children. iter ( ) )
1749- {
1750- tlc_preimages. insert ( ( * channel_id, * tlc_id) , child. preimage ) ;
1756+ MppMode :: AtomicMpp => {
1757+ let mut atomic_mpp_data =
1758+ self . store . get_atomic_mpp_payment_data ( & payment_hash) ;
1759+
1760+ if atomic_mpp_data. len ( ) != tlcs. len ( ) {
1761+ error ! (
1762+ "atomic mpp don't have enough mpp data for payment_hash: {:?}" ,
1763+ payment_hash
1764+ ) ;
1765+ tlc_fail = Some ( TlcErr :: new (
1766+ TlcErrorCode :: IncorrectOrUnknownPaymentDetails ,
1767+ ) ) ;
1768+ break ' validation;
1769+ }
1770+
1771+ atomic_mpp_data. sort_by ( |a, b| a. 1 . index . cmp ( & b. 1 . index ) ) ;
1772+ let index: Vec < u16 > =
1773+ atomic_mpp_data. iter ( ) . map ( |a| a. 1 . index ) . collect ( ) ;
1774+ let expected_index: Vec < u16 > = ( 0 ..tlcs. len ( ) as u16 ) . collect ( ) ;
1775+
1776+ if index != expected_index {
1777+ error ! (
1778+ "atomic mpp index are not expected for payment_hash: {:?}" ,
1779+ payment_hash
1780+ ) ;
1781+ tlc_fail = Some ( TlcErr :: new (
1782+ TlcErrorCode :: IncorrectOrUnknownPaymentDetails ,
1783+ ) ) ;
1784+ break ' validation;
1785+ }
1786+
1787+ let child_descs: Vec < ChildDesc > = atomic_mpp_data
1788+ . iter ( )
1789+ . map ( |( _, data) | ChildDesc :: new ( data. secret , data. index ) )
1790+ . collect ( ) ;
1791+ let children = reconstruct_children ( & child_descs, hash_algorithm) ;
1792+ debug_assert_eq ! ( child_descs. len( ) , children. len( ) ) ;
1793+
1794+ for ( ( ( channel_id, tlc_id) , _) , child) in
1795+ atomic_mpp_data. iter ( ) . zip ( children. iter ( ) )
1796+ {
1797+ tlc_preimage_map. insert ( ( * channel_id, * tlc_id) , child. preimage ) ;
1798+ }
17511799 }
17521800 }
1753- }
1801+ } // end of 'validation block
17541802
17551803 // remove tlcs
17561804 for tlc in tlcs {
17571805 let ( send, _recv) = oneshot:: channel ( ) ;
17581806 let rpc_reply = RpcReplyPort :: from ( send) ;
1807+
17591808 let remove_reason = match tlc_fail. clone ( ) {
17601809 Some ( tlc_fail) => RemoveTlcReason :: RemoveTlcFail ( TlcErrPacket :: new (
17611810 tlc_fail,
17621811 & tlc. shared_secret ,
17631812 ) ) ,
17641813 None => {
1765- let preimage = * tlc_preimages
1814+ let preimage = * tlc_preimage_map
17661815 . get ( & ( tlc. channel_id , tlc. id ( ) ) )
17671816 . expect ( "must got preimage" ) ;
17681817 RemoveTlcReason :: RemoveTlcFulfill ( RemoveTlcFulfill {
17691818 payment_preimage : preimage,
17701819 } )
17711820 }
17721821 } ;
1822+
17731823 match state
17741824 . send_command_to_channel (
17751825 tlc. channel_id ,
0 commit comments