Skip to content

Commit 71fb4fa

Browse files
committed
Batch incoming commit_sig in PeerConnection
We move the incoming `commit_sig` batching logic outside of the channel and into the `PeerConnection` instead. This slightly simplifies the channel FSM and its tests, since the `PeerConnection` actor is simpler. We unfortunately cannot easily do this in the `TransportHandler` because of our buffered read of the encrypted messages, which may split batches and make it more complex to correctly group messages.
1 parent 5dbd63a commit 71fb4fa

File tree

7 files changed

+242
-145
lines changed

7 files changed

+242
-145
lines changed

eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1159,7 +1159,7 @@ case class Commitments(params: ChannelParams,
11591159
case ChannelSpendSignature.IndividualSignature(latestRemoteSig) => latestRemoteSig == commitSig.signature
11601160
case ChannelSpendSignature.PartialSignatureWithNonce(_, _) => ???
11611161
}
1162-
params.channelFeatures.hasFeature(Features.DualFunding) && commitSig.batchSize == 1 && isLatestSig
1162+
params.channelFeatures.hasFeature(Features.DualFunding) && isLatestSig
11631163
}
11641164

11651165
def localFundingSigs(fundingTxId: TxId): Option[TxSignatures] = {

eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/Channel.scala

Lines changed: 91 additions & 113 deletions
Large diffs are not rendered by default.

eclair-core/src/main/scala/fr/acinq/eclair/io/PeerConnection.scala

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ package fr.acinq.eclair.io
1818

1919
import akka.actor.{ActorRef, FSM, OneForOneStrategy, PoisonPill, Props, Stash, SupervisorStrategy, Terminated}
2020
import akka.event.Logging.MDC
21-
import fr.acinq.bitcoin.scalacompat.BlockHash
2221
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
22+
import fr.acinq.bitcoin.scalacompat.{BlockHash, ByteVector32}
2323
import fr.acinq.eclair.Logs.LogCategory
2424
import fr.acinq.eclair.crypto.Noise.KeyPair
2525
import fr.acinq.eclair.crypto.TransportHandler
@@ -343,10 +343,48 @@ class PeerConnection(keyPair: KeyPair, conf: PeerConnection.Conf, switchboard: A
343343
stay()
344344

345345
case Event(msg: LightningMessage, d: ConnectedData) =>
346-
// we acknowledge and pass all other messages to the peer
346+
// We immediately acknowledge all other messages.
347347
d.transport ! TransportHandler.ReadAck(msg)
348-
d.peer ! msg
349-
stay()
348+
// We immediately forward messages to the peer, unless they are part of a batch, in which case we wait to
349+
// receive the whole batch before forwarding.
350+
msg match {
351+
case msg: CommitSig =>
352+
msg.tlvStream.get[CommitSigTlv.BatchTlv].map(_.size) match {
353+
case Some(batchSize) if batchSize > 25 =>
354+
log.warning("received legacy batch of commit_sig exceeding our threshold ({} > 25), processing messages individually", batchSize)
355+
// We don't want peers to be able to exhaust our memory by sending batches of dummy messages that we keep in RAM.
356+
d.peer ! msg
357+
stay()
358+
case Some(batchSize) if batchSize > 1 =>
359+
d.legacyCommitSigBatch_opt match {
360+
case Some(pending) if pending.channelId != msg.channelId || pending.batchSize != batchSize =>
361+
log.warning("received invalid commit_sig batch while a different batch isn't complete")
362+
// This should never happen, otherwise it will likely lead to a force-close.
363+
d.peer ! CommitSigBatch(pending.received)
364+
stay() using d.copy(legacyCommitSigBatch_opt = Some(PendingCommitSigBatch(msg.channelId, batchSize, Seq(msg))))
365+
case Some(pending) =>
366+
val received1 = pending.received :+ msg
367+
if (received1.size == batchSize) {
368+
log.debug("received last commit_sig in legacy batch for channel_id={}", msg.channelId)
369+
d.peer ! CommitSigBatch(received1)
370+
stay() using d.copy(legacyCommitSigBatch_opt = None)
371+
} else {
372+
log.debug("received commit_sig {}/{} in legacy batch for channel_id={}", received1.size, batchSize, msg.channelId)
373+
stay() using d.copy(legacyCommitSigBatch_opt = Some(pending.copy(received = received1)))
374+
}
375+
case None =>
376+
log.debug("received first commit_sig in legacy batch of size {} for channel_id={}", batchSize, msg.channelId)
377+
stay() using d.copy(legacyCommitSigBatch_opt = Some(PendingCommitSigBatch(msg.channelId, batchSize, Seq(msg))))
378+
}
379+
case _ =>
380+
log.debug("received individual commit_sig for channel_id={}", msg.channelId)
381+
d.peer ! msg
382+
stay()
383+
}
384+
case _ =>
385+
d.peer ! msg
386+
stay()
387+
}
350388

351389
case Event(readAck: TransportHandler.ReadAck, d: ConnectedData) =>
352390
// we just forward acks to the transport (e.g. from the router)
@@ -566,8 +604,19 @@ object PeerConnection {
566604
case class AuthenticatingData(pendingAuth: PendingAuth, transport: ActorRef, isPersistent: Boolean) extends Data with HasTransport
567605
case class BeforeInitData(remoteNodeId: PublicKey, pendingAuth: PendingAuth, transport: ActorRef, isPersistent: Boolean) extends Data with HasTransport
568606
case class InitializingData(chainHash: BlockHash, pendingAuth: PendingAuth, remoteNodeId: PublicKey, transport: ActorRef, peer: ActorRef, localInit: protocol.Init, doSync: Boolean, isPersistent: Boolean) extends Data with HasTransport
569-
case class ConnectedData(chainHash: BlockHash, remoteNodeId: PublicKey, transport: ActorRef, peer: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, rebroadcastDelay: FiniteDuration, gossipTimestampFilter: Option[GossipTimestampFilter] = None, behavior: Behavior = Behavior(), expectedPong_opt: Option[ExpectedPong] = None, isPersistent: Boolean) extends Data with HasTransport
570-
607+
case class ConnectedData(chainHash: BlockHash,
608+
remoteNodeId: PublicKey,
609+
transport: ActorRef,
610+
peer: ActorRef,
611+
localInit: protocol.Init, remoteInit: protocol.Init,
612+
rebroadcastDelay: FiniteDuration,
613+
gossipTimestampFilter: Option[GossipTimestampFilter] = None,
614+
behavior: Behavior = Behavior(),
615+
expectedPong_opt: Option[ExpectedPong] = None,
616+
legacyCommitSigBatch_opt: Option[PendingCommitSigBatch] = None,
617+
isPersistent: Boolean) extends Data with HasTransport
618+
619+
case class PendingCommitSigBatch(channelId: ByteVector32, batchSize: Int, received: Seq[CommitSig])
571620
case class ExpectedPong(ping: Ping, timestamp: TimestampMilli = TimestampMilli.now())
572621
case class PingTimeout(ping: Ping)
573622

eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -440,9 +440,7 @@ object CommitSigs {
440440
case class CommitSig(channelId: ByteVector32,
441441
signature: ByteVector64,
442442
htlcSignatures: List[ByteVector64],
443-
tlvStream: TlvStream[CommitSigTlv] = TlvStream.empty) extends CommitSigs {
444-
val batchSize: Int = tlvStream.get[CommitSigTlv.BatchTlv].map(_.size).getOrElse(1)
445-
}
443+
tlvStream: TlvStream[CommitSigTlv] = TlvStream.empty) extends CommitSigs
446444

447445
case class CommitSigBatch(messages: Seq[CommitSig]) extends CommitSigs {
448446
require(messages.map(_.channelId).toSet.size == 1, "commit_sig messages in a batch must be for the same channel")

eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -463,23 +463,17 @@ trait ChannelStateTestsBase extends Assertions with Eventually {
463463
val rHasChanges = r.stateData.asInstanceOf[ChannelDataWithCommitments].commitments.changes.localHasChanges
464464
s ! CMD_SIGN(Some(sender.ref))
465465
sender.expectMsgType[RES_SUCCESS[CMD_SIGN]]
466-
s2r.expectMsgType[CommitSigs] match {
467-
case sig: CommitSig => s2r.forward(r, sig)
468-
case batch: CommitSigBatch => batch.messages.foreach(sig => s2r.forward(r, sig))
469-
}
466+
s2r.expectMsgType[CommitSigs]
467+
s2r.forward(r)
470468
r2s.expectMsgType[RevokeAndAck]
471469
r2s.forward(s)
472-
r2s.expectMsgType[CommitSigs] match {
473-
case sig: CommitSig => r2s.forward(s, sig)
474-
case batch: CommitSigBatch => batch.messages.foreach(sig => r2s.forward(s, sig))
475-
}
470+
r2s.expectMsgType[CommitSigs]
471+
r2s.forward(s)
476472
s2r.expectMsgType[RevokeAndAck]
477473
s2r.forward(r)
478474
if (rHasChanges) {
479-
s2r.expectMsgType[CommitSigs] match {
480-
case sig: CommitSig => s2r.forward(r, sig)
481-
case batch: CommitSigBatch => batch.messages.foreach(sig => s2r.forward(r, sig))
482-
}
475+
s2r.expectMsgType[CommitSigs]
476+
s2r.forward(r)
483477
r2s.expectMsgType[RevokeAndAck]
484478
r2s.forward(s)
485479
eventually {

eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalSplicesStateSpec.scala

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,12 +1519,12 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik
15191519
alice ! CMD_SIGN()
15201520
val sigsA = alice2bob.expectMsgType[CommitSigBatch]
15211521
assert(sigsA.batchSize == 2)
1522-
sigsA.messages.foreach(sig => alice2bob.forward(bob, sig))
1522+
alice2bob.forward(bob, sigsA)
15231523
bob2alice.expectMsgType[RevokeAndAck]
15241524
bob2alice.forward(alice)
15251525
val sigsB = bob2alice.expectMsgType[CommitSigBatch]
15261526
assert(sigsB.batchSize == 2)
1527-
sigsB.messages.foreach(sig => bob2alice.forward(alice, sig))
1527+
bob2alice.forward(alice, sigsB)
15281528
alice2bob.expectMsgType[RevokeAndAck]
15291529
alice2bob.forward(bob)
15301530
awaitCond(alice.stateData.asInstanceOf[DATA_NORMAL].commitments.active.forall(_.localCommit.spec.htlcs.size == 1))
@@ -1548,12 +1548,12 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik
15481548
alice2bob.forward(bob)
15491549
val sigsA = alice2bob.expectMsgType[CommitSigBatch]
15501550
assert(sigsA.batchSize == 2)
1551-
sigsA.messages.foreach(sig => alice2bob.forward(bob, sig))
1551+
alice2bob.forward(bob, sigsA)
15521552
bob2alice.expectMsgType[RevokeAndAck]
15531553
bob2alice.forward(alice)
15541554
val sigsB = bob2alice.expectMsgType[CommitSigBatch]
15551555
assert(sigsB.batchSize == 2)
1556-
sigsB.messages.foreach(sig => bob2alice.forward(alice, sig))
1556+
bob2alice.forward(alice, sigsB)
15571557
alice2bob.expectMsgType[RevokeAndAck]
15581558
alice2bob.forward(bob)
15591559
awaitCond(alice.stateData.asInstanceOf[DATA_NORMAL].commitments.active.forall(_.localCommit.spec.htlcs.size == 1))
@@ -1672,16 +1672,14 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik
16721672
alice ! CMD_SIGN()
16731673
val commitSigsAlice = alice2bob.expectMsgType[CommitSigBatch]
16741674
assert(commitSigsAlice.batchSize == 3)
1675-
alice2bob.forward(bob, commitSigsAlice.messages(0))
16761675
bob ! WatchPublishedTriggered(spliceTx2)
16771676
val spliceLockedBob = bob2alice.expectMsgType[SpliceLocked]
16781677
assert(spliceLockedBob.fundingTxId == spliceTx2.txid)
16791678
bob2alice.forward(alice, spliceLockedBob)
1680-
alice2bob.forward(bob, commitSigsAlice.messages(1))
1681-
alice2bob.forward(bob, commitSigsAlice.messages(2))
1679+
alice2bob.forward(bob, commitSigsAlice)
16821680
bob2alice.expectMsgType[RevokeAndAck]
16831681
bob2alice.forward(alice)
1684-
assert(bob2alice.expectMsgType[CommitSig].batchSize == 1)
1682+
bob2alice.expectMsgType[CommitSig]
16851683
bob2alice.forward(alice)
16861684
alice2bob.expectMsgType[RevokeAndAck]
16871685
alice2bob.forward(bob)
@@ -3335,13 +3333,13 @@ class NormalSplicesStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLik
33353333
bob ! CMD_SIGN()
33363334
inside(bob2alice.expectMsgType[CommitSigBatch]) { batch =>
33373335
assert(batch.batchSize == 3)
3338-
batch.messages.foreach(sig => bob2alice.forward(alice, sig))
3336+
bob2alice.forward(alice, batch)
33393337
}
33403338
alice2bob.expectMsgType[RevokeAndAck]
33413339
alice2bob.forward(bob)
33423340
inside(alice2bob.expectMsgType[CommitSigBatch]) { batch =>
33433341
assert(batch.batchSize == 3)
3344-
batch.messages.foreach(sig => alice2bob.forward(bob, sig))
3342+
alice2bob.forward(bob, batch)
33453343
}
33463344
bob2alice.expectMsgType[RevokeAndAck]
33473345
bob2alice.forward(alice)

eclair-core/src/test/scala/fr/acinq/eclair/io/PeerConnectionSpec.scala

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32}
2323
import fr.acinq.eclair.FeatureSupport.{Mandatory, Optional}
2424
import fr.acinq.eclair.Features._
2525
import fr.acinq.eclair.TestConstants._
26+
import fr.acinq.eclair.TestUtils.randomTxId
2627
import fr.acinq.eclair.crypto.TransportHandler
2728
import fr.acinq.eclair.io.Peer.ConnectionDown
2829
import fr.acinq.eclair.message.OnionMessages.{Recipient, buildMessage}
@@ -348,6 +349,85 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi
348349
transport.expectNoMessage(100 millis)
349350
}
350351

352+
test("receive legacy batch of commit_sig messages") { f =>
353+
import f._
354+
connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer)
355+
356+
// We receive a batch of commit_sig messages from a first channel.
357+
val channelId1 = randomBytes32()
358+
val commitSigs1 = Seq(
359+
CommitSig(channelId1, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))),
360+
CommitSig(channelId1, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))),
361+
)
362+
transport.send(peerConnection, commitSigs1.head)
363+
transport.expectMsg(TransportHandler.ReadAck(commitSigs1.head))
364+
peer.expectNoMessage(100 millis)
365+
transport.send(peerConnection, commitSigs1.last)
366+
transport.expectMsg(TransportHandler.ReadAck(commitSigs1.last))
367+
peer.expectMsg(CommitSigBatch(commitSigs1))
368+
369+
// We receive a batch of commit_sig messages from a second channel.
370+
val channelId2 = randomBytes32()
371+
val commitSigs2 = Seq(
372+
CommitSig(channelId2, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(3))),
373+
CommitSig(channelId2, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(3))),
374+
CommitSig(channelId2, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(3))),
375+
)
376+
commitSigs2.dropRight(1).foreach(commitSig => {
377+
transport.send(peerConnection, commitSig)
378+
transport.expectMsg(TransportHandler.ReadAck(commitSig))
379+
})
380+
peer.expectNoMessage(100 millis)
381+
transport.send(peerConnection, commitSigs2.last)
382+
transport.expectMsg(TransportHandler.ReadAck(commitSigs2.last))
383+
peer.expectMsg(CommitSigBatch(commitSigs2))
384+
385+
// We receive another batch of commit_sig messages from the first channel, with unrelated messages in the batch.
386+
val commitSigs3 = Seq(
387+
CommitSig(channelId1, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))),
388+
CommitSig(channelId1, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))),
389+
)
390+
transport.send(peerConnection, commitSigs3.head)
391+
transport.expectMsg(TransportHandler.ReadAck(commitSigs3.head))
392+
val spliceLocked1 = SpliceLocked(channelId1, randomTxId())
393+
transport.send(peerConnection, spliceLocked1)
394+
transport.expectMsg(TransportHandler.ReadAck(spliceLocked1))
395+
peer.expectMsg(spliceLocked1)
396+
val spliceLocked2 = SpliceLocked(channelId2, randomTxId())
397+
transport.send(peerConnection, spliceLocked2)
398+
transport.expectMsg(TransportHandler.ReadAck(spliceLocked2))
399+
peer.expectMsg(spliceLocked2)
400+
peer.expectNoMessage(100 millis)
401+
transport.send(peerConnection, commitSigs3.last)
402+
transport.expectMsg(TransportHandler.ReadAck(commitSigs3.last))
403+
peer.expectMsg(CommitSigBatch(commitSigs3))
404+
405+
// We start receiving a batch of commit_sig messages from the first channel, interleaved with a batch from the second
406+
// channel, which is not supported.
407+
val commitSigs4 = Seq(
408+
CommitSig(channelId1, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))),
409+
CommitSig(channelId2, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))),
410+
CommitSig(channelId2, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(2))),
411+
)
412+
transport.send(peerConnection, commitSigs4.head)
413+
transport.expectMsg(TransportHandler.ReadAck(commitSigs4.head))
414+
peer.expectNoMessage(100 millis)
415+
transport.send(peerConnection, commitSigs4(1))
416+
transport.expectMsg(TransportHandler.ReadAck(commitSigs4(1)))
417+
peer.expectMsg(CommitSigBatch(commitSigs4.take(1)))
418+
transport.send(peerConnection, commitSigs4.last)
419+
transport.expectMsg(TransportHandler.ReadAck(commitSigs4.last))
420+
peer.expectMsg(CommitSigBatch(commitSigs4.tail))
421+
422+
// We receive a batch that exceeds our threshold: we process them individually.
423+
val invalidCommitSigs = (0 until 30).map(_ => CommitSig(channelId2, randomBytes64(), Nil, TlvStream(CommitSigTlv.BatchTlv(30))))
424+
invalidCommitSigs.foreach(commitSig => {
425+
transport.send(peerConnection, commitSig)
426+
transport.expectMsg(TransportHandler.ReadAck(commitSig))
427+
peer.expectMsg(commitSig)
428+
})
429+
}
430+
351431
test("react to peer's bad behavior") { f =>
352432
import f._
353433
val probe = TestProbe()

0 commit comments

Comments
 (0)