Skip to content

Commit 4e24a41

Browse files
committed
Add utils to persist scorer in BackgroundProcessor
1 parent 62edee5 commit 4e24a41

File tree

3 files changed

+114
-13
lines changed

3 files changed

+114
-13
lines changed

lightning-background-processor/src/lib.rs

+90-10
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use lightning::ln::channelmanager::ChannelManager;
1818
use lightning::ln::msgs::{ChannelMessageHandler, RoutingMessageHandler};
1919
use lightning::ln::peer_handler::{CustomMessageHandler, PeerManager, SocketDescriptor};
2020
use lightning::routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
21+
use lightning::routing::scoring::WriteableScore;
2122
use lightning::util::events::{Event, EventHandler, EventsProvider};
2223
use lightning::util::logger::Logger;
2324
use lightning::util::persist::Persister;
@@ -151,6 +152,7 @@ impl BackgroundProcessor {
151152
/// [`NetworkGraph`]: lightning::routing::network_graph::NetworkGraph
152153
/// [`NetworkGraph::write`]: lightning::routing::network_graph::NetworkGraph#impl-Writeable
153154
pub fn start<
155+
'a,
154156
Signer: 'static + Sign,
155157
CA: 'static + Deref + Send + Sync,
156158
CF: 'static + Deref + Send + Sync,
@@ -171,9 +173,10 @@ impl BackgroundProcessor {
171173
NG: 'static + Deref<Target = NetGraphMsgHandler<G, CA, L>> + Send + Sync,
172174
UMH: 'static + Deref + Send + Sync,
173175
PM: 'static + Deref<Target = PeerManager<Descriptor, CMH, RMH, L, UMH>> + Send + Sync,
176+
S: 'static + Deref + Send + Sync,
174177
>(
175178
persister: PS, event_handler: EH, chain_monitor: M, channel_manager: CM,
176-
net_graph_msg_handler: Option<NG>, peer_manager: PM, logger: L
179+
net_graph_msg_handler: Option<NG>, peer_manager: PM, logger: L, scorer: Option<S>
177180
) -> Self
178181
where
179182
CA::Target: 'static + chain::Access,
@@ -187,7 +190,8 @@ impl BackgroundProcessor {
187190
CMH::Target: 'static + ChannelMessageHandler,
188191
RMH::Target: 'static + RoutingMessageHandler,
189192
UMH::Target: 'static + CustomMessageHandler,
190-
PS::Target: 'static + Persister<Signer, CW, T, K, F, L>
193+
PS::Target: 'static + Persister<'a, Signer, CW, T, K, F, L, S>,
194+
S::Target: WriteableScore<'a>,
191195
{
192196
let stop_thread = Arc::new(AtomicBool::new(false));
193197
let stop_thread_clone = stop_thread.clone();
@@ -277,6 +281,11 @@ impl BackgroundProcessor {
277281
last_prune_call = Instant::now();
278282
have_pruned = true;
279283
}
284+
if let Some(ref scorer) = scorer {
285+
if let Err(e) = persister.persist_scorer(&scorer) {
286+
log_error!(logger, "Error: Failed to persist scorer, check your disk and permissions {}", e)
287+
}
288+
}
280289
}
281290
}
282291

@@ -285,10 +294,16 @@ impl BackgroundProcessor {
285294
// ChannelMonitor update(s) persisted without a corresponding ChannelManager update.
286295
persister.persist_manager(&*channel_manager)?;
287296

297+
// Persist Scorer on exit
298+
if let Some(ref scorer) = scorer {
299+
persister.persist_scorer(&scorer)?;
300+
}
301+
288302
// Persist NetworkGraph on exit
289303
if let Some(ref handler) = net_graph_msg_handler {
290304
persister.persist_graph(handler.network_graph())?;
291305
}
306+
292307
Ok(())
293308
});
294309
Self { stop_thread: stop_thread_clone, thread_handle: Some(handle) }
@@ -411,12 +426,13 @@ mod tests {
411426
graph_error: Option<(std::io::ErrorKind, &'static str)>,
412427
manager_error: Option<(std::io::ErrorKind, &'static str)>,
413428
filesystem_persister: FilesystemPersister,
429+
scorer_error: Option<(std::io::ErrorKind, &'static str)>
414430
}
415431

416432
impl Persister {
417433
fn new(data_dir: String) -> Self {
418434
let filesystem_persister = FilesystemPersister::new(data_dir.clone());
419-
Self { graph_error: None, manager_error: None, filesystem_persister }
435+
Self { graph_error: None, manager_error: None, scorer_error: None, filesystem_persister }
420436
}
421437

422438
fn with_graph_error(self, error: std::io::ErrorKind, message: &'static str) -> Self {
@@ -426,6 +442,10 @@ mod tests {
426442
fn with_manager_error(self, error: std::io::ErrorKind, message: &'static str) -> Self {
427443
Self { manager_error: Some((error, message)), ..self }
428444
}
445+
446+
fn with_scorer_error(self, error: std::io::ErrorKind, message: &'static str) -> Self {
447+
Self { scorer_error: Some((error, message)), ..self }
448+
}
429449
}
430450

431451
impl KVStorePersister for Persister {
@@ -442,6 +462,12 @@ mod tests {
442462
}
443463
}
444464

465+
if key == "scorer" {
466+
if let Some((error, message)) = self.scorer_error {
467+
return Err(std::io::Error::new(error, message))
468+
}
469+
}
470+
445471
self.filesystem_persister.persist(key, object)
446472
}
447473
}
@@ -571,7 +597,8 @@ mod tests {
571597
let data_dir = nodes[0].persister.get_data_dir();
572598
let persister = Arc::new(Persister::new(data_dir));
573599
let event_handler = |_: &_| {};
574-
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
600+
let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0)));
601+
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(scorer.clone()));
575602

576603
macro_rules! check_persisted_data {
577604
($node: expr, $filepath: expr) => {
@@ -597,6 +624,30 @@ mod tests {
597624
}
598625
}
599626

627+
macro_rules! check_mutex_persisted_data {
628+
($node: expr, $filepath: expr) => {
629+
let mut expected_bytes = Vec::new();
630+
loop {
631+
expected_bytes.clear();
632+
match $node.lock().unwrap().write(&mut expected_bytes) {
633+
Ok(_) => {
634+
match std::fs::read($filepath) {
635+
Ok(bytes) => {
636+
if bytes == expected_bytes {
637+
break
638+
} else {
639+
continue
640+
}
641+
},
642+
Err(_) => continue
643+
}
644+
},
645+
Err(e) => panic!("Unexpected error: {}", e)
646+
}
647+
}
648+
}
649+
}
650+
600651
// Check that the initial channel manager data is persisted as expected.
601652
let filepath = get_full_filepath("test_background_processor_persister_0".to_string(), "manager".to_string());
602653
check_persisted_data!(nodes[0].node, filepath.clone());
@@ -621,6 +672,10 @@ mod tests {
621672
check_persisted_data!(network_graph, filepath.clone());
622673
}
623674

675+
// Check scorer is persisted
676+
let filepath = get_full_filepath("test_background_processor_persister_0".to_string(), "scorer".to_string());
677+
check_mutex_persisted_data!(scorer, filepath.clone());
678+
624679
assert!(bg_processor.stop().is_ok());
625680
}
626681

@@ -632,7 +687,8 @@ mod tests {
632687
let data_dir = nodes[0].persister.get_data_dir();
633688
let persister = Arc::new(Persister::new(data_dir));
634689
let event_handler = |_: &_| {};
635-
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
690+
let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0)));
691+
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(scorer));
636692
loop {
637693
let log_entries = nodes[0].logger.lines.lock().unwrap();
638694
let desired_log = "Calling ChannelManager's timer_tick_occurred".to_string();
@@ -655,7 +711,8 @@ mod tests {
655711
let data_dir = nodes[0].persister.get_data_dir();
656712
let persister = Arc::new(Persister::new(data_dir).with_manager_error(std::io::ErrorKind::Other, "test"));
657713
let event_handler = |_: &_| {};
658-
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
714+
let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0)));
715+
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(scorer));
659716
match bg_processor.join() {
660717
Ok(_) => panic!("Expected error persisting manager"),
661718
Err(e) => {
@@ -672,7 +729,8 @@ mod tests {
672729
let data_dir = nodes[0].persister.get_data_dir();
673730
let persister = Arc::new(Persister::new(data_dir).with_graph_error(std::io::ErrorKind::Other, "test"));
674731
let event_handler = |_: &_| {};
675-
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
732+
let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0)));
733+
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(scorer));
676734

677735
match bg_processor.stop() {
678736
Ok(_) => panic!("Expected error persisting network graph"),
@@ -683,6 +741,25 @@ mod tests {
683741
}
684742
}
685743

744+
#[test]
745+
fn test_scorer_persist_error() {
746+
// Test that if we encounter an error during scorer persistence, an error gets returned.
747+
let nodes = create_nodes(2, "test_persist_scorer_error".to_string());
748+
let data_dir = nodes[0].persister.get_data_dir();
749+
let persister = Arc::new(Persister::new(data_dir).with_scorer_error(std::io::ErrorKind::Other, "test"));
750+
let event_handler = |_: &_| {};
751+
let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0)));
752+
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(scorer));
753+
754+
match bg_processor.stop() {
755+
Ok(_) => panic!("Expected error persisting scorer"),
756+
Err(e) => {
757+
assert_eq!(e.kind(), std::io::ErrorKind::Other);
758+
assert_eq!(e.get_ref().unwrap().to_string(), "test");
759+
},
760+
}
761+
}
762+
686763
#[test]
687764
fn test_background_event_handling() {
688765
let mut nodes = create_nodes(2, "test_background_event_handling".to_string());
@@ -695,7 +772,8 @@ mod tests {
695772
let event_handler = move |event: &Event| {
696773
sender.send(handle_funding_generation_ready!(event, channel_value)).unwrap();
697774
};
698-
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
775+
let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0)));
776+
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(scorer));
699777

700778
// Open a channel and check that the FundingGenerationReady event was handled.
701779
begin_open_channel!(nodes[0], nodes[1], channel_value);
@@ -720,7 +798,8 @@ mod tests {
720798
let (sender, receiver) = std::sync::mpsc::sync_channel(1);
721799
let event_handler = move |event: &Event| sender.send(event.clone()).unwrap();
722800
let persister = Arc::new(Persister::new(data_dir));
723-
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
801+
let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0)));
802+
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(scorer));
724803

725804
// Force close the channel and check that the SpendableOutputs event was handled.
726805
nodes[0].node.force_close_channel(&nodes[0].node.list_channels()[0].channel_id).unwrap();
@@ -751,7 +830,8 @@ mod tests {
751830
let router = DefaultRouter::new(Arc::clone(&nodes[0].network_graph), Arc::clone(&nodes[0].logger), random_seed_bytes);
752831
let invoice_payer = Arc::new(InvoicePayer::new(Arc::clone(&nodes[0].node), router, scorer, Arc::clone(&nodes[0].logger), |_: &_| {}, RetryAttempts(2)));
753832
let event_handler = Arc::clone(&invoice_payer);
754-
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
833+
let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0)));
834+
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(scorer));
755835
assert!(bg_processor.stop().is_ok());
756836
}
757837
}

lightning/src/routing/scoring.rs

+10
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,16 @@ pub trait LockableScore<'a> {
137137
fn lock(&'a self) -> Self::Locked;
138138
}
139139

140+
/// Refers to a scorer that is accessible under lock and also writeable to disk
141+
///
142+
/// We need this trait to be able to pass in a scorer to `lightning-background-processor` that will enable us to
143+
/// use the Persister to persist it.
144+
pub trait WriteableScore<'a>: LockableScore<'a> + Writeable {}
145+
146+
impl<'a, T> WriteableScore<'a> for T
147+
where T: LockableScore<'a> + Writeable
148+
{}
149+
140150
/// (C-not exported)
141151
impl<'a, T: 'a + Score> LockableScore<'a> for Mutex<T> {
142152
type Locked = MutexGuard<'a, T>;

lightning/src/util/persist.rs

+14-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
use core::ops::Deref;
1212
use bitcoin::hashes::hex::ToHex;
1313
use io::{self};
14+
use routing::scoring::WriteableScore;
1415

1516
use crate::{chain::{keysinterface::{Sign, KeysInterface}, self, transaction::{OutPoint}, chaininterface::{BroadcasterInterface, FeeEstimator}, chainmonitor::{Persist, MonitorUpdateId}, channelmonitor::{ChannelMonitor, ChannelMonitorUpdate}}, ln::channelmanager::ChannelManager, routing::network_graph::NetworkGraph};
1617
use super::{logger::Logger, ser::Writeable};
@@ -24,27 +25,32 @@ pub trait KVStorePersister {
2425
fn persist<W: Writeable>(&self, key: &str, object: &W) -> io::Result<()>;
2526
}
2627

27-
/// Trait that handles persisting a [`ChannelManager`] and [`NetworkGraph`] to disk.
28-
pub trait Persister<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
28+
/// Trait that handles persisting a [`ChannelManager`], [`NetworkGraph`], and [`WriteableScore`] to disk.
29+
pub trait Persister<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref, S:Deref>
2930
where M::Target: 'static + chain::Watch<Signer>,
3031
T::Target: 'static + BroadcasterInterface,
3132
K::Target: 'static + KeysInterface<Signer = Signer>,
3233
F::Target: 'static + FeeEstimator,
3334
L::Target: 'static + Logger,
35+
S::Target: 'static + WriteableScore<'a>
3436
{
3537
/// Persist the given ['ChannelManager'] to disk, returning an error if persistence failed.
3638
fn persist_manager(&self, channel_manager: &ChannelManager<Signer, M, T, K, F, L>) -> Result<(), io::Error>;
3739

3840
/// Persist the given [`NetworkGraph`] to disk, returning an error if persistence failed.
3941
fn persist_graph(&self, network_graph: &NetworkGraph) -> Result<(), io::Error>;
42+
43+
/// Persist the given [`WriteableScore`] to disk, returning an error if persistence failed.
44+
fn persist_scorer(&self, scorer: &S) -> Result<(), io::Error>;
4045
}
4146

42-
impl<A: KVStorePersister, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Persister<Signer, M, T, K, F, L> for A
47+
impl<'a, A: KVStorePersister, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref, S: Deref> Persister<'a, Signer, M, T, K, F, L, S> for A
4348
where M::Target: 'static + chain::Watch<Signer>,
4449
T::Target: 'static + BroadcasterInterface,
4550
K::Target: 'static + KeysInterface<Signer = Signer>,
4651
F::Target: 'static + FeeEstimator,
4752
L::Target: 'static + Logger,
53+
S::Target: 'static + WriteableScore<'a> + Sized
4854
{
4955
/// Persist the given ['ChannelManager'] to disk, returning an error if persistence failed.
5056
fn persist_manager(&self, channel_manager: &ChannelManager<Signer, M, T, K, F, L>) -> Result<(), io::Error> {
@@ -55,6 +61,11 @@ impl<A: KVStorePersister, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref,
5561
fn persist_graph(&self, network_graph: &NetworkGraph) -> Result<(), io::Error> {
5662
self.persist("network_graph", network_graph)
5763
}
64+
65+
/// Persist the given [`WriteableScore`] to disk, returning an error if persistence failed.
66+
fn persist_scorer(&self, scorer: &S) -> Result<(), io::Error> {
67+
self.persist("scorer", scorer.deref())
68+
}
5869
}
5970

6071
impl<ChannelSigner: Sign, K: KVStorePersister> Persist<ChannelSigner> for K {

0 commit comments

Comments
 (0)