Skip to content

Commit 26118a7

Browse files
committed
Write peer storage to DB
1 parent db27c71 commit 26118a7

File tree

10 files changed

+179
-27
lines changed

10 files changed

+179
-27
lines changed

eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ case class NodeParams(nodeKeyManager: NodeKeyManager,
9292
revokedHtlcInfoCleanerConfig: RevokedHtlcInfoCleaner.Config,
9393
willFundRates_opt: Option[LiquidityAds.WillFundRates],
9494
peerWakeUpConfig: PeerReadyNotifier.WakeUpConfig,
95-
onTheFlyFundingConfig: OnTheFlyFunding.Config) {
95+
onTheFlyFundingConfig: OnTheFlyFunding.Config,
96+
peerStorageWriteDelayMax: FiniteDuration) {
9697
val privateKey: Crypto.PrivateKey = nodeKeyManager.nodeKey.privateKey
9798

9899
val nodeId: PublicKey = nodeKeyManager.nodeId
@@ -678,6 +679,7 @@ object NodeParams extends Logging {
678679
onTheFlyFundingConfig = OnTheFlyFunding.Config(
679680
proposalTimeout = FiniteDuration(config.getDuration("on-the-fly-funding.proposal-timeout").getSeconds, TimeUnit.SECONDS),
680681
),
682+
peerStorageWriteDelayMax = 1 minute,
681683
)
682684
}
683685
}

eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import fr.acinq.eclair.router.Router
1515
import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, NodeAddress, NodeAnnouncement}
1616
import fr.acinq.eclair.{CltvExpiry, MilliSatoshi, Paginated, RealShortChannelId, ShortChannelId, TimestampMilli}
1717
import grizzled.slf4j.Logging
18+
import scodec.bits.ByteVector
1819

1920
import java.io.File
2021
import java.util.UUID
@@ -292,6 +293,16 @@ case class DualPeersDb(primary: PeersDb, secondary: PeersDb) extends PeersDb {
292293
runAsync(secondary.getRelayFees(nodeId))
293294
primary.getRelayFees(nodeId)
294295
}
296+
297+
override def updateStorage(nodeId: PublicKey, data: ByteVector): Unit = {
298+
runAsync(secondary.updateStorage(nodeId, data))
299+
primary.updateStorage(nodeId, data)
300+
}
301+
302+
override def getStorage(nodeId: PublicKey): Option[ByteVector] = {
303+
runAsync(secondary.getStorage(nodeId))
304+
primary.getStorage(nodeId)
305+
}
295306
}
296307

297308
case class DualPaymentsDb(primary: PaymentsDb, secondary: PaymentsDb) extends PaymentsDb {

eclair-core/src/main/scala/fr/acinq/eclair/db/PeersDb.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package fr.acinq.eclair.db
1919
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
2020
import fr.acinq.eclair.payment.relay.Relayer.RelayFees
2121
import fr.acinq.eclair.wire.protocol.NodeAddress
22+
import scodec.bits.ByteVector
2223

2324
trait PeersDb {
2425

@@ -34,4 +35,8 @@ trait PeersDb {
3435

3536
def getRelayFees(nodeId: PublicKey): Option[RelayFees]
3637

38+
def updateStorage(nodeId: PublicKey, data: ByteVector): Unit
39+
40+
def getStorage(nodeId: PublicKey): Option[ByteVector]
41+
3742
}

eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@ import fr.acinq.eclair.db.pg.PgUtils.PgLock
2626
import fr.acinq.eclair.payment.relay.Relayer.RelayFees
2727
import fr.acinq.eclair.wire.protocol._
2828
import grizzled.slf4j.Logging
29-
import scodec.bits.BitVector
29+
import scodec.bits.{BitVector, ByteVector}
3030

3131
import java.sql.Statement
3232
import javax.sql.DataSource
3333

3434
object PgPeersDb {
3535
val DB_NAME = "peers"
36-
val CURRENT_VERSION = 3
36+
val CURRENT_VERSION = 4
3737
}
3838

3939
class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logging {
@@ -54,20 +54,28 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logg
5454
statement.executeUpdate("CREATE TABLE local.relay_fees (node_id TEXT NOT NULL PRIMARY KEY, fee_base_msat BIGINT NOT NULL, fee_proportional_millionths BIGINT NOT NULL)")
5555
}
5656

57+
def migration34(statement: Statement): Unit = {
58+
statement.executeUpdate("CREATE TABLE local.peer_storage (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)")
59+
}
60+
5761
using(pg.createStatement()) { statement =>
5862
getVersion(statement, DB_NAME) match {
5963
case None =>
6064
statement.executeUpdate("CREATE SCHEMA IF NOT EXISTS local")
61-
statement.executeUpdate("CREATE TABLE local.peers (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)")
65+
statement.executeUpdate("CREATE TABLE local.peers (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, storage BYTEA)")
6266
statement.executeUpdate("CREATE TABLE local.relay_fees (node_id TEXT NOT NULL PRIMARY KEY, fee_base_msat BIGINT NOT NULL, fee_proportional_millionths BIGINT NOT NULL)")
63-
case Some(v@(1 | 2)) =>
67+
statement.executeUpdate("CREATE TABLE local.peer_storage (node_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL)")
68+
case Some(v@(1 | 2 | 3)) =>
6469
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
6570
if (v < 2) {
6671
migration12(statement)
6772
}
6873
if (v < 3) {
6974
migration23(statement)
7075
}
76+
if (v < 4) {
77+
migration34(statement)
78+
}
7179
case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do
7280
case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
7381
}
@@ -98,6 +106,10 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logg
98106
statement.setString(1, nodeId.value.toHex)
99107
statement.executeUpdate()
100108
}
109+
using(pg.prepareStatement("DELETE FROM local.peer_storage WHERE node_id = ?")) { statement =>
110+
statement.setString(1, nodeId.value.toHex)
111+
statement.executeUpdate()
112+
}
101113
}
102114
}
103115

@@ -155,4 +167,31 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb with Logg
155167
}
156168
}
157169
}
170+
171+
override def updateStorage(nodeId: PublicKey, data: ByteVector): Unit = withMetrics("peers/update-storage", DbBackends.Postgres) {
172+
withLock { pg =>
173+
using(pg.prepareStatement(
174+
"""
175+
INSERT INTO local.peer_storage (node_id, data)
176+
VALUES (?, ?)
177+
ON CONFLICT (node_id)
178+
DO UPDATE SET data = EXCLUDED.data
179+
""")) { statement =>
180+
statement.setString(1, nodeId.value.toHex)
181+
statement.setBytes(2, data.toArray)
182+
statement.executeUpdate()
183+
}
184+
}
185+
}
186+
187+
override def getStorage(nodeId: PublicKey): Option[ByteVector] = withMetrics("peers/get-storage", DbBackends.Postgres) {
188+
withLock { pg =>
189+
using(pg.prepareStatement("SELECT data FROM local.peer_storage WHERE node_id = ?")) { statement =>
190+
statement.setString(1, nodeId.value.toHex)
191+
statement.executeQuery()
192+
.headOption
193+
.map(rs => ByteVector(rs.getBytes("data")))
194+
}
195+
}
196+
}
158197
}

eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ import fr.acinq.eclair.db.sqlite.SqliteUtils.{getVersion, setVersion, using}
2626
import fr.acinq.eclair.payment.relay.Relayer.RelayFees
2727
import fr.acinq.eclair.wire.protocol._
2828
import grizzled.slf4j.Logging
29-
import scodec.bits.BitVector
29+
import scodec.bits.{BitVector, ByteVector}
3030

3131
import java.sql.{Connection, Statement}
3232

3333
object SqlitePeersDb {
3434
val DB_NAME = "peers"
35-
val CURRENT_VERSION = 2
35+
val CURRENT_VERSION = 3
3636
}
3737

3838
class SqlitePeersDb(val sqlite: Connection) extends PeersDb with Logging {
@@ -46,13 +46,23 @@ class SqlitePeersDb(val sqlite: Connection) extends PeersDb with Logging {
4646
statement.executeUpdate("CREATE TABLE relay_fees (node_id BLOB NOT NULL PRIMARY KEY, fee_base_msat INTEGER NOT NULL, fee_proportional_millionths INTEGER NOT NULL)")
4747
}
4848

49+
def migration23(statement: Statement): Unit = {
50+
statement.executeUpdate("CREATE TABLE peer_storage (node_id BLOB NOT NULL PRIMARY KEY, data NOT NULL)")
51+
}
52+
4953
getVersion(statement, DB_NAME) match {
5054
case None =>
5155
statement.executeUpdate("CREATE TABLE peers (node_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL)")
5256
statement.executeUpdate("CREATE TABLE relay_fees (node_id BLOB NOT NULL PRIMARY KEY, fee_base_msat INTEGER NOT NULL, fee_proportional_millionths INTEGER NOT NULL)")
53-
case Some(v@1) =>
57+
statement.executeUpdate("CREATE TABLE peer_storage (node_id BLOB NOT NULL PRIMARY KEY, data NOT NULL)")
58+
case Some(v@(1 | 2)) =>
5459
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
55-
migration12(statement)
60+
if (v < 2) {
61+
migration12(statement)
62+
}
63+
if (v < 3) {
64+
migration23(statement)
65+
}
5666
case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do
5767
case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
5868
}
@@ -128,4 +138,27 @@ class SqlitePeersDb(val sqlite: Connection) extends PeersDb with Logging {
128138
)
129139
}
130140
}
141+
142+
override def updateStorage(nodeId: PublicKey, data: ByteVector): Unit = withMetrics("peers/update-storage", DbBackends.Sqlite) {
143+
using(sqlite.prepareStatement("UPDATE peer_storage SET data = ? WHERE node_id = ?")) { update =>
144+
update.setBytes(1, data.toArray)
145+
update.setBytes(2, nodeId.value.toArray)
146+
if (update.executeUpdate() == 0) {
147+
using(sqlite.prepareStatement("INSERT INTO peer_storage VALUES (?, ?)")) { statement =>
148+
statement.setBytes(1, nodeId.value.toArray)
149+
statement.setBytes(2, data.toArray)
150+
statement.executeUpdate()
151+
}
152+
}
153+
}
154+
}
155+
156+
override def getStorage(nodeId: PublicKey): Option[ByteVector] = withMetrics("peers/get-storage", DbBackends.Sqlite) {
157+
using(sqlite.prepareStatement("SELECT data FROM peer_storage WHERE node_id = ?")) { statement =>
158+
statement.setBytes(1, nodeId.value.toArray)
159+
statement.executeQuery()
160+
.headOption
161+
.map(rs => ByteVector(rs.getBytes("data")))
162+
}
163+
}
131164
}

eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,7 @@ import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes
4444
import fr.acinq.eclair.router.Router
4545
import fr.acinq.eclair.wire.protocol
4646
import fr.acinq.eclair.wire.protocol.FailureMessageCodecs.createBadOnionFailure
47-
import fr.acinq.eclair.wire.protocol.{AddFeeCredit, ChannelTlv, CurrentFeeCredit, Error, HasChannelId, HasTemporaryChannelId, LightningMessage, LiquidityAds, NodeAddress, OnTheFlyFundingFailureMessage, OnionMessage, OnionRoutingPacket, RoutingMessage, SpliceInit, TlvStream, UnknownMessage, Warning, WillAddHtlc, WillFailHtlc, WillFailMalformedHtlc}
48-
import fr.acinq.eclair.wire.protocol.LiquidityAds.PaymentDetails
49-
import fr.acinq.eclair.wire.protocol.{Error, HasChannelId, HasTemporaryChannelId, LightningMessage, LiquidityAds, NodeAddress, OnTheFlyFundingFailureMessage, OnionMessage, OnionRoutingPacket, PeerStorageRetrieval, PeerStorageStore, RoutingMessage, SpliceInit, UnknownMessage, Warning, WillAddHtlc, WillFailHtlc, WillFailMalformedHtlc}
47+
import fr.acinq.eclair.wire.protocol.{AddFeeCredit, ChannelTlv, CurrentFeeCredit, Error, HasChannelId, HasTemporaryChannelId, LightningMessage, LiquidityAds, NodeAddress, OnTheFlyFundingFailureMessage, OnionMessage, OnionRoutingPacket, PeerStorageRetrieval, PeerStorageStore, RoutingMessage, SpliceInit, TlvStream, UnknownMessage, Warning, WillAddHtlc, WillFailHtlc, WillFailMalformedHtlc}
5048
import scodec.bits.ByteVector
5149

5250
/**
@@ -87,7 +85,7 @@ class Peer(val nodeParams: NodeParams,
8785
FinalChannelId(state.channelId) -> channel
8886
}.toMap
8987
context.system.eventStream.publish(PeerCreated(self, remoteNodeId))
90-
goto(DISCONNECTED) using DisconnectedData(channels, None) // when we restart, we will attempt to reconnect right away, but then we'll wait
88+
goto(DISCONNECTED) using DisconnectedData(channels, PeerStorage(nodeParams.db.peers.getStorage(remoteNodeId), written = true, TimestampMilli.min)) // when we restart, we will attempt to reconnect right away, but then we'll wait
9189
}
9290

9391
when(DISCONNECTED) {
@@ -510,7 +508,19 @@ class Peer(val nodeParams: NodeParams,
510508
stay()
511509

512510
case Event(store: PeerStorageStore, d: ConnectedData) if nodeParams.features.hasFeature(Features.ProvideStorage) && d.channels.nonEmpty =>
513-
stay() using d.copy(peerStorage = Some(store.blob))
511+
val timeSinceLastWrite = TimestampMilli.now() - d.peerStorage.lastWrite
512+
val peerStorage = if (timeSinceLastWrite >= nodeParams.peerStorageWriteDelayMax) {
513+
nodeParams.db.peers.updateStorage(remoteNodeId, store.blob)
514+
PeerStorage(Some(store.blob), written = true, TimestampMilli.now())
515+
} else {
516+
startSingleTimer("peer-storage-write", WritePeerStorage, nodeParams.peerStorageWriteDelayMax - timeSinceLastWrite)
517+
PeerStorage(Some(store.blob), written = false, d.peerStorage.lastWrite)
518+
}
519+
stay() using d.copy(peerStorage = peerStorage)
520+
521+
case Event(WritePeerStorage, d: ConnectedData) =>
522+
d.peerStorage.data.foreach(nodeParams.db.peers.updateStorage(remoteNodeId, _))
523+
stay() using d.copy(peerStorage = PeerStorage(d.peerStorage.data, written = true, TimestampMilli.now()))
514524

515525
case Event(unhandledMsg: LightningMessage, _) =>
516526
log.warning("ignoring message {}", unhandledMsg)
@@ -722,7 +732,7 @@ class Peer(val nodeParams: NodeParams,
722732
context.system.eventStream.publish(PeerDisconnected(self, remoteNodeId))
723733
}
724734

725-
private def gotoConnected(connectionReady: PeerConnection.ConnectionReady, channels: Map[ChannelId, ActorRef], peerStorage: Option[ByteVector]): State = {
735+
private def gotoConnected(connectionReady: PeerConnection.ConnectionReady, channels: Map[ChannelId, ActorRef], peerStorage: PeerStorage): State = {
726736
require(remoteNodeId == connectionReady.remoteNodeId, s"invalid nodeId: $remoteNodeId != ${connectionReady.remoteNodeId}")
727737
log.debug("got authenticated connection to address {}", connectionReady.address)
728738

@@ -733,7 +743,7 @@ class Peer(val nodeParams: NodeParams,
733743
}
734744

735745
// If we have some data stored from our peer, we send it to them before doing anything else.
736-
peerStorage.foreach(connectionReady.peerConnection ! PeerStorageRetrieval(_))
746+
peerStorage.data.foreach(connectionReady.peerConnection ! PeerStorageRetrieval(_))
737747

738748
// let's bring existing/requested channels online
739749
channels.values.toSet[ActorRef].foreach(_ ! INPUT_RECONNECTED(connectionReady.peerConnection, connectionReady.localInit, connectionReady.remoteInit)) // we deduplicate with toSet because there might be two entries per channel (tmp id and final id)
@@ -886,16 +896,18 @@ object Peer {
886896
case class TemporaryChannelId(id: ByteVector32) extends ChannelId
887897
case class FinalChannelId(id: ByteVector32) extends ChannelId
888898

899+
case class PeerStorage(data: Option[ByteVector], written: Boolean, lastWrite: TimestampMilli)
900+
889901
sealed trait Data {
890902
def channels: Map[_ <: ChannelId, ActorRef] // will be overridden by Map[FinalChannelId, ActorRef] or Map[ChannelId, ActorRef]
891-
def peerStorage: Option[ByteVector]
903+
def peerStorage: PeerStorage
892904
}
893905
case object Nothing extends Data {
894906
override def channels = Map.empty
895-
override def peerStorage: Option[ByteVector] = None
907+
override def peerStorage: PeerStorage = PeerStorage(None, written = true, TimestampMilli.min)
896908
}
897-
case class DisconnectedData(channels: Map[FinalChannelId, ActorRef], peerStorage: Option[ByteVector]) extends Data
898-
case class ConnectedData(address: NodeAddress, peerConnection: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, channels: Map[ChannelId, ActorRef], peerStorage: Option[ByteVector]) extends Data {
909+
case class DisconnectedData(channels: Map[FinalChannelId, ActorRef], peerStorage: PeerStorage) extends Data
910+
case class ConnectedData(address: NodeAddress, peerConnection: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, channels: Map[ChannelId, ActorRef], peerStorage: PeerStorage) extends Data {
899911
val connectionInfo: ConnectionInfo = ConnectionInfo(address, peerConnection, localInit, remoteInit)
900912
def localFeatures: Features[InitFeature] = localInit.features
901913
def remoteFeatures: Features[InitFeature] = remoteInit.features
@@ -1006,5 +1018,7 @@ object Peer {
10061018
case class RelayOnionMessage(messageId: ByteVector32, msg: OnionMessage, replyTo_opt: Option[typed.ActorRef[Status]])
10071019

10081020
case class RelayUnknownMessage(unknownMessage: UnknownMessage)
1021+
1022+
case object WritePeerStorage
10091023
// @formatter:on
10101024
}

eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ object TestConstants {
106106
Features.PaymentMetadata -> FeatureSupport.Optional,
107107
Features.RouteBlinding -> FeatureSupport.Optional,
108108
Features.StaticRemoteKey -> FeatureSupport.Mandatory,
109+
Features.ProvideStorage -> FeatureSupport.Optional,
109110
),
110111
unknown = Set(UnknownFeature(TestFeature.optional))
111112
),
@@ -238,6 +239,7 @@ object TestConstants {
238239
willFundRates_opt = Some(defaultLiquidityRates),
239240
peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(enabled = false, timeout = 30 seconds),
240241
onTheFlyFundingConfig = OnTheFlyFunding.Config(proposalTimeout = 90 seconds),
242+
peerStorageWriteDelayMax = 5 seconds,
241243
)
242244

243245
def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams(
@@ -412,6 +414,7 @@ object TestConstants {
412414
willFundRates_opt = Some(defaultLiquidityRates),
413415
peerWakeUpConfig = PeerReadyNotifier.WakeUpConfig(enabled = false, timeout = 30 seconds),
414416
onTheFlyFundingConfig = OnTheFlyFunding.Config(proposalTimeout = 90 seconds),
417+
peerStorageWriteDelayMax = 5 seconds,
415418
)
416419

417420
def channelParams: LocalParams = OpenChannelInterceptor.makeChannelParams(

eclair-core/src/test/scala/fr/acinq/eclair/db/PeersDbSpec.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import fr.acinq.eclair.payment.relay.Relayer.RelayFees
2424
import fr.acinq.eclair._
2525
import fr.acinq.eclair.wire.protocol.{NodeAddress, Tor2, Tor3}
2626
import org.scalatest.funsuite.AnyFunSuite
27+
import scodec.bits.HexStringSyntax
2728

2829
import java.util.concurrent.Executors
2930
import scala.concurrent.duration._
@@ -107,4 +108,24 @@ class PeersDbSpec extends AnyFunSuite {
107108
}
108109
}
109110

111+
test("peer storage") {
112+
forAllDbs { dbs =>
113+
val db = dbs.peers
114+
115+
val a = randomKey().publicKey
116+
val b = randomKey().publicKey
117+
118+
assert(db.getStorage(a) == None)
119+
assert(db.getStorage(b) == None)
120+
db.updateStorage(a, hex"012345")
121+
assert(db.getStorage(a) == Some(hex"012345"))
122+
assert(db.getStorage(b) == None)
123+
db.updateStorage(a, hex"6789")
124+
assert(db.getStorage(a) == Some(hex"6789"))
125+
assert(db.getStorage(b) == None)
126+
db.updateStorage(b, hex"abcd")
127+
assert(db.getStorage(a) == Some(hex"6789"))
128+
assert(db.getStorage(b) == Some(hex"abcd"))
129+
}
130+
}
110131
}

0 commit comments

Comments
 (0)