Skip to content

Commit 85fa11f

Browse files
committed
code cleanup
1 parent f27abf3 commit 85fa11f

File tree

3 files changed

+55
-57
lines changed

3 files changed

+55
-57
lines changed

crates/fiber-lib/src/fiber/network.rs

Lines changed: 47 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,39 +1694,41 @@ where
16941694
})
16951695
.collect();
16961696

1697-
let mut tlc_fail = None;
16981697
let mut hash_algorithm = HashAlgorithm::default();
16991698
let mut tlc_preimage_map = HashMap::new();
17001699

17011700
// we generally set all validation error as IncorrectOrUnknownPaymentDetails, this failure
1702-
// is expected to be the scenarios malicious sender have sent a inconsistent data with MPP
1703-
'validation: {
1701+
// is expected to be the scenarios malicious sender have sent a inconsistent data with MPP,
1702+
// and we don't want to keep them in the database.
1703+
let tlc_fail = 'validation: {
1704+
macro_rules! validation_fail {
1705+
($msg:expr) => {{
1706+
error!($msg, payment_hash);
1707+
break 'validation Some(TlcErr::new(
1708+
TlcErrorCode::IncorrectOrUnknownPaymentDetails,
1709+
));
1710+
}};
1711+
}
1712+
17041713
// check if all tlcs have the same total amount
17051714
if tlcs.len() > 1
17061715
&& !tlcs
17071716
.windows(2)
17081717
.all(|w| w[0].total_amount == w[1].total_amount)
17091718
{
1710-
error!("TLCs have inconsistent total_amount: {:?}", tlcs);
1711-
break 'validation tlc_fail =
1712-
Some(TlcErr::new(TlcErrorCode::IncorrectOrUnknownPaymentDetails));
1719+
validation_fail!("TLCs have inconsistent total_amount: {:?}");
17131720
}
17141721

17151722
// check if tlc set are fulfilled
17161723
let invoice = self.store.get_invoice(&payment_hash);
17171724
let Some(invoice) = invoice else {
1718-
error!(
1719-
"Try to settle mpp tlc set, but invoice not found for payment hash {:?}",
1720-
payment_hash
1721-
);
1722-
break 'validation tlc_fail =
1723-
Some(TlcErr::new(TlcErrorCode::IncorrectOrUnknownPaymentDetails));
1725+
validation_fail!(
1726+
"Try to settle mpp tlc set, but invoice not found for payment hash {:?}"
1727+
)
17241728
};
17251729

17261730
let Some(mpp_mode) = invoice.mpp_mode() else {
1727-
error!("try to settle down mpp payment_hash: {:?} while the invoice does no support MPP", payment_hash);
1728-
break 'validation tlc_fail =
1729-
Some(TlcErr::new(TlcErrorCode::IncorrectOrUnknownPaymentDetails));
1731+
validation_fail!("try to settle down mpp payment_hash: {:?} while the invoice does no support MPP");
17301732
};
17311733

17321734
if !is_invoice_fulfilled(&invoice, &tlcs) {
@@ -1738,49 +1740,45 @@ where
17381740
match mpp_mode {
17391741
MppMode::BasicMpp => {
17401742
let Some(preimage) = self.store.get_preimage(&payment_hash) else {
1741-
error!(
1742-
"basic MPP can not get preimage for payment: {:?}",
1743-
payment_hash
1743+
validation_fail!(
1744+
"basic MPP can not get preimage for payment: {:?}"
17441745
);
1745-
break 'validation tlc_fail = Some(TlcErr::new(
1746-
TlcErrorCode::IncorrectOrUnknownPaymentDetails,
1747-
));
17481746
};
17491747

17501748
for tlc in tlcs.iter() {
17511749
tlc_preimage_map.insert((tlc.channel_id, tlc.id()), preimage);
17521750
}
17531751
}
17541752
MppMode::AtomicMpp => {
1755-
let mut atomic_mpp_data =
1753+
let mut tlcs_mpp_data =
17561754
self.store.get_atomic_mpp_payment_data(&payment_hash);
17571755

1758-
if atomic_mpp_data.len() != tlcs.len() {
1759-
error!(
1760-
"atomic mpp don't have enough mpp data for payment_hash: {:?}",
1761-
payment_hash
1756+
if tlcs_mpp_data.len() != tlcs.len() {
1757+
validation_fail!(
1758+
"atomic mpp don't have enough mpp data for payment_hash: {:?}"
17621759
);
1763-
break 'validation tlc_fail = Some(TlcErr::new(
1764-
TlcErrorCode::IncorrectOrUnknownPaymentDetails,
1765-
));
17661760
}
17671761

1768-
atomic_mpp_data.sort_by(|a, b| a.1.index().cmp(&b.1.index()));
1762+
tlcs_mpp_data.sort_by(|(_, a), (_, b)| a.index().cmp(&b.index()));
17691763
let index: Vec<u16> =
1770-
atomic_mpp_data.iter().map(|a| a.1.index()).collect();
1771-
let expected_index: Vec<u16> = (0..tlcs.len() as u16).collect();
1764+
tlcs_mpp_data.iter().map(|(_, data)| data.index()).collect();
17721765

1766+
let total_count = tlcs_mpp_data[0].1.total_amp_count;
1767+
if tlcs_mpp_data
1768+
.iter()
1769+
.any(|(_, data)| data.total_amp_count != total_count)
1770+
{
1771+
validation_fail!("atomic mpp total count are not the same for payment_hash: {:?}");
1772+
}
1773+
1774+
let expected_index: Vec<u16> = (0..total_count).collect();
17731775
if index != expected_index {
1774-
error!(
1775-
"atomic mpp index are not expected for payment_hash: {:?}",
1776-
payment_hash
1776+
validation_fail!(
1777+
"atomic mpp index are not expected for payment_hash: {:?}"
17771778
);
1778-
break 'validation tlc_fail = Some(TlcErr::new(
1779-
TlcErrorCode::IncorrectOrUnknownPaymentDetails,
1780-
));
17811779
}
17821780

1783-
let child_descs: Vec<AmpChildDesc> = atomic_mpp_data
1781+
let child_descs: Vec<AmpChildDesc> = tlcs_mpp_data
17841782
.iter()
17851783
.map(|(_, data)| data.child_desc.clone())
17861784
.collect();
@@ -1789,7 +1787,7 @@ where
17891787
debug_assert_eq!(child_descs.len(), children.len());
17901788

17911789
for (((channel_id, tlc_id), _), child) in
1792-
atomic_mpp_data.iter().zip(children.iter())
1790+
tlcs_mpp_data.iter().zip(children.iter())
17931791
{
17941792
tlc_preimage_map.insert((*channel_id, *tlc_id), child.preimage);
17951793
}
@@ -1803,18 +1801,15 @@ where
18031801
let hash: Hash256 = hash_algorithm.hash(preimage.as_ref()).into();
18041802

18051803
if hash != payment_hash {
1806-
error!(
1807-
"verify AMP preimage {:?} for payment hash {:?} is not valid",
1808-
preimage, payment_hash
1804+
validation_fail!(
1805+
"verify AMP preimage for payment hash {:?} is not valid"
18091806
);
1810-
break 'validation tlc_fail = Some(TlcErr::new(
1811-
TlcErrorCode::IncorrectOrUnknownPaymentDetails,
1812-
));
18131807
}
18141808
}
18151809
}
18161810
}
1817-
} // end of 'validation block
1811+
None
1812+
}; // end of 'validation block
18181813

18191814
// remove tlcs
18201815
for tlc in tlcs {
@@ -2767,16 +2762,17 @@ where
27672762
let hash_algorithm = attempts[0].route_hops.first().unwrap().hash_algorithm;
27682763

27692764
let total_count = attempts.len() as u16;
2770-
let amps: Vec<AmpPaymentData> = secrets
2765+
let child_descs: Vec<_> = secrets
27712766
.iter()
27722767
.enumerate()
2773-
.map(|(i, &share)| AmpPaymentData::new(payment_hash, i as u16, total_count, share))
2768+
.map(|(i, &share)| AmpChildDesc::new(i as u16, share))
27742769
.collect();
27752770

2776-
let child_descs: Vec<_> = amps.iter().map(|x| x.child_desc.clone()).collect();
27772771
let children = AmpChild::construct_amp_children(&child_descs, hash_algorithm);
2778-
for ((attempt, amp_data), child) in attempts.iter_mut().zip(&amps).zip(&children) {
2772+
for (index, (attempt, child)) in attempts.iter_mut().zip(&children).enumerate() {
27792773
let last_hop = attempt.route_hops.last_mut().expect("last hop");
2774+
let amp_data =
2775+
AmpPaymentData::new(payment_hash, total_count, child_descs[index].clone());
27802776
let mut custom_records = last_hop.custom_records.clone().unwrap_or_default();
27812777
amp_data.write(&mut custom_records);
27822778
last_hop.custom_records = Some(custom_records);

crates/fiber-lib/src/fiber/tests/types.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,11 @@ fn test_basic_mpp_custom_records() {
735735
fn test_amp_custom_records() {
736736
let mut payment_custom_records = PaymentCustomRecords::default();
737737
let parent_payment_hash = gen_rand_sha256_hash();
738-
let amp_record = AmpPaymentData::new(parent_payment_hash, 0, 3, AmpSecret::random());
738+
let amp_record = AmpPaymentData::new(
739+
parent_payment_hash,
740+
3,
741+
AmpChildDesc::new(0, AmpSecret::random()),
742+
);
739743
amp_record.write(&mut payment_custom_records);
740744

741745
let new_amp_record = AmpPaymentData::read(&payment_custom_records).unwrap();

crates/fiber-lib/src/fiber/types.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3795,12 +3795,11 @@ pub struct AmpPaymentData {
37953795
impl AmpPaymentData {
37963796
pub const CUSTOM_RECORD_KEY: u32 = USER_CUSTOM_RECORDS_MAX_INDEX + 2;
37973797

3798-
pub fn new(payment_hash: Hash256, index: u16, total_amp_count: u16, secret: AmpSecret) -> Self {
3799-
debug_assert!(index < total_amp_count);
3798+
pub fn new(payment_hash: Hash256, total_amp_count: u16, child_desc: AmpChildDesc) -> Self {
38003799
Self {
38013800
payment_hash,
38023801
total_amp_count,
3803-
child_desc: AmpChildDesc::new(index, secret),
3802+
child_desc,
38043803
}
38053804
}
38063805

@@ -3837,9 +3836,8 @@ impl AmpPaymentData {
38373836
let secret = AmpSecret::new(data[36..68].try_into().unwrap());
38383837
Some(Self::new(
38393838
Hash256::from(parent_hash),
3840-
index,
38413839
total_amp_count,
3842-
secret,
3840+
AmpChildDesc::new(index, secret),
38433841
))
38443842
})
38453843
}

0 commit comments

Comments
 (0)