Skip to content

Commit fee8737

Browse files
committed
Refactor channel_type validation
As suggested by @sstone.
1 parent 3693e12 commit fee8737

File tree

7 files changed

+35
-45
lines changed

7 files changed

+35
-45
lines changed

eclair-core/src/main/scala/fr/acinq/eclair/channel/ChannelFeatures.scala

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package fr.acinq.eclair.channel
1818

19+
import fr.acinq.bitcoin.scalacompat.ByteVector32
1920
import fr.acinq.eclair.transactions.Transactions._
2021
import fr.acinq.eclair.{ChannelTypeFeature, FeatureSupport, Features, InitFeature, PermanentChannelFeature}
2122

@@ -140,15 +141,12 @@ object ChannelTypes {
140141
def fromFeatures(features: Features[InitFeature]): ChannelType = features2ChannelType.getOrElse(features, UnsupportedChannelType(features))
141142

142143
/** Check if a given channel type is compatible with our features. */
143-
def areCompatible(localFeatures: Features[InitFeature], remoteChannelType: ChannelType): Option[SupportedChannelType] = remoteChannelType match {
144-
case _: UnsupportedChannelType => None
144+
def areCompatible(channelId: ByteVector32, localFeatures: Features[InitFeature], remoteChannelType_opt: Option[ChannelType]): Either[ChannelException, SupportedChannelType] = remoteChannelType_opt match {
145+
case None => Left(MissingChannelType(channelId))
146+
case Some(channelType: UnsupportedChannelType) => Left(InvalidChannelType(channelId, channelType))
145147
// We ensure that we support the features necessary for this channel type.
146-
case proposedChannelType: SupportedChannelType =>
147-
if (proposedChannelType.features.forall(f => localFeatures.hasFeature(f))) {
148-
Some(proposedChannelType)
149-
} else {
150-
None
151-
}
148+
case Some(proposedChannelType: SupportedChannelType) if proposedChannelType.features.forall(f => localFeatures.hasFeature(f)) => Right(proposedChannelType)
149+
case Some(proposedChannelType: SupportedChannelType) => Left(InvalidChannelType(channelId, proposedChannelType))
152150
}
153151

154152
}

eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ object Helpers {
9999
}
100100

101101
/** Called by the fundee of a single-funded channel. */
102-
def validateParamsSingleFundedFundee(nodeParams: NodeParams, channelType: SupportedChannelType, localFeatures: Features[InitFeature], open: OpenChannel, remoteNodeId: PublicKey, remoteFeatures: Features[InitFeature]): Either[ChannelException, (ChannelFeatures, Option[ByteVector])] = {
102+
def validateParamsSingleFundedFundee(nodeParams: NodeParams, localFeatures: Features[InitFeature], open: OpenChannel, remoteNodeId: PublicKey, remoteFeatures: Features[InitFeature]): Either[ChannelException, (ChannelFeatures, Option[ByteVector])] = {
103103
// BOLT #2: if the chain_hash value, within the open_channel, message is set to a hash of a chain that is unknown to the receiver:
104104
// MUST reject the channel.
105105
if (nodeParams.chainHash != open.chainHash) return Left(InvalidChainHash(open.temporaryChannelId, local = nodeParams.chainHash, remote = open.chainHash))
@@ -133,6 +133,10 @@ object Helpers {
133133
return Left(ChannelReserveNotMet(open.temporaryChannelId, toLocalMsat, toRemoteMsat, open.channelReserveSatoshis))
134134
}
135135

136+
val channelType = ChannelTypes.areCompatible(open.temporaryChannelId, localFeatures, open.channelType_opt) match {
137+
case Left(f) => return Left(f)
138+
case Right(proposedChannelType) => proposedChannelType
139+
}
136140
val channelFeatures = ChannelFeatures(channelType, localFeatures, remoteFeatures, open.channelFlags.announceChannel)
137141
channelType.commitmentFormat match {
138142
case _: SimpleTaprootChannelCommitmentFormat => if (open.commitNonce_opt.isEmpty) return Left(MissingCommitNonce(open.temporaryChannelId, TxId(ByteVector32.Zeroes), commitmentNumber = 0))
@@ -154,7 +158,6 @@ object Helpers {
154158

155159
/** Called by the non-initiator of a dual-funded channel. */
156160
def validateParamsDualFundedNonInitiator(nodeParams: NodeParams,
157-
channelType: SupportedChannelType,
158161
open: OpenDualFundedChannel,
159162
fundingScript: ByteVector,
160163
remoteNodeId: PublicKey,
@@ -184,6 +187,10 @@ object Helpers {
184187
if (open.dustLimit < Channel.MIN_DUST_LIMIT) return Left(DustLimitTooSmall(open.temporaryChannelId, open.dustLimit, Channel.MIN_DUST_LIMIT))
185188
if (open.dustLimit > nodeParams.channelConf.maxRemoteDustLimit) return Left(DustLimitTooLarge(open.temporaryChannelId, open.dustLimit, nodeParams.channelConf.maxRemoteDustLimit))
186189

190+
val channelType = ChannelTypes.areCompatible(open.temporaryChannelId, localFeatures, open.channelType_opt) match {
191+
case Left(f) => return Left(f)
192+
case Right(proposedChannelType) => proposedChannelType
193+
}
187194
val channelFeatures = ChannelFeatures(channelType, localFeatures, remoteFeatures, open.channelFlags.announceChannel)
188195

189196
// BOLT #2: The receiving node MUST fail the channel if: it considers feerate_per_kw too small for timely processing or unreasonably large.
@@ -196,22 +203,19 @@ object Helpers {
196203
} yield (channelFeatures, script_opt, willFund_opt)
197204
}
198205

199-
private def validateChannelType(channelId: ByteVector32, channelType: SupportedChannelType, openChannelType_opt: Option[ChannelType], acceptChannelType_opt: Option[ChannelType]): Option[ChannelException] = {
200-
if (openChannelType_opt.isEmpty || acceptChannelType_opt.isEmpty) {
201-
Some(MissingChannelType(channelId))
202-
} else if (!openChannelType_opt.contains(channelType) || !acceptChannelType_opt.contains(channelType)) {
203-
Some(InvalidChannelType(channelId, acceptChannelType_opt.get))
204-
} else {
205-
// we agree on channel type
206-
None
206+
private def validateChannelTypeInitiator(channelId: ByteVector32, openChannelType_opt: Option[ChannelType], acceptChannelType_opt: Option[ChannelType]): Either[ChannelException, SupportedChannelType] = {
207+
(openChannelType_opt, acceptChannelType_opt) match {
208+
case (Some(proposed: SupportedChannelType), Some(received)) if proposed == received => Right(proposed)
209+
case (Some(_), Some(received)) => Left(InvalidChannelType(channelId, received))
210+
case _ => Left(MissingChannelType(channelId))
207211
}
208212
}
209213

210214
/** Called by the funder of a single-funded channel. */
211-
def validateParamsSingleFundedFunder(nodeParams: NodeParams, channelType: SupportedChannelType, localFeatures: Features[InitFeature], remoteFeatures: Features[InitFeature], open: OpenChannel, accept: AcceptChannel): Either[ChannelException, (ChannelFeatures, Option[ByteVector])] = {
212-
validateChannelType(open.temporaryChannelId, channelType, open.channelType_opt, accept.channelType_opt) match {
213-
case Some(t) => return Left(t)
214-
case None => // we agree on channel type
215+
def validateParamsSingleFundedFunder(nodeParams: NodeParams, localFeatures: Features[InitFeature], remoteFeatures: Features[InitFeature], open: OpenChannel, accept: AcceptChannel): Either[ChannelException, (ChannelFeatures, Option[ByteVector])] = {
216+
val channelType = validateChannelTypeInitiator(open.temporaryChannelId, open.channelType_opt, accept.channelType_opt) match {
217+
case Left(t) => return Left(t)
218+
case Right(channelType) => channelType
215219
}
216220

217221
if (accept.maxAcceptedHtlcs > Channel.MAX_ACCEPTED_HTLCS) return Left(InvalidMaxAcceptedHtlcs(accept.temporaryChannelId, accept.maxAcceptedHtlcs, Channel.MAX_ACCEPTED_HTLCS))
@@ -248,14 +252,13 @@ object Helpers {
248252
/** Called by the initiator of a dual-funded channel. */
249253
def validateParamsDualFundedInitiator(nodeParams: NodeParams,
250254
remoteNodeId: PublicKey,
251-
channelType: SupportedChannelType,
252255
localFeatures: Features[InitFeature],
253256
remoteFeatures: Features[InitFeature],
254257
open: OpenDualFundedChannel,
255258
accept: AcceptDualFundedChannel): Either[ChannelException, (ChannelFeatures, Option[ByteVector], Option[LiquidityAds.Purchase])] = {
256-
validateChannelType(open.temporaryChannelId, channelType, open.channelType_opt, accept.channelType_opt) match {
257-
case Some(t) => return Left(t)
258-
case None => // we agree on channel type
259+
val channelType = validateChannelTypeInitiator(open.temporaryChannelId, open.channelType_opt, accept.channelType_opt) match {
260+
case Left(t) => return Left(t)
261+
case Right(channelType) => channelType
259262
}
260263

261264
// BOLT #2: Channel funding limits

eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/ChannelOpenDualFunded.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ trait ChannelOpenDualFunded extends DualFundingHandlers with ErrorHandlers {
141141
case Event(open: OpenDualFundedChannel, d: DATA_WAIT_FOR_OPEN_DUAL_FUNDED_CHANNEL) =>
142142
val localFundingPubkey = channelKeys.fundingKey(fundingTxIndex = 0).publicKey
143143
val fundingScript = Transactions.makeFundingScript(localFundingPubkey, open.fundingPubkey, d.init.channelType.commitmentFormat).pubkeyScript
144-
Helpers.validateParamsDualFundedNonInitiator(nodeParams, d.init.channelType, open, fundingScript, remoteNodeId, d.init.localChannelParams.initFeatures, d.init.remoteInit.features, d.init.fundingContribution_opt) match {
144+
Helpers.validateParamsDualFundedNonInitiator(nodeParams, open, fundingScript, remoteNodeId, d.init.localChannelParams.initFeatures, d.init.remoteInit.features, d.init.fundingContribution_opt) match {
145145
case Left(t) => handleLocalError(t, d, Some(open))
146146
case Right((channelFeatures, remoteShutdownScript, willFund_opt)) =>
147147
context.system.eventStream.publish(ChannelCreated(self, peer, remoteNodeId, isOpener = false, open.temporaryChannelId, open.commitmentFeerate, Some(open.fundingFeerate)))
@@ -224,7 +224,7 @@ trait ChannelOpenDualFunded extends DualFundingHandlers with ErrorHandlers {
224224

225225
when(WAIT_FOR_ACCEPT_DUAL_FUNDED_CHANNEL)(handleExceptions {
226226
case Event(accept: AcceptDualFundedChannel, d: DATA_WAIT_FOR_ACCEPT_DUAL_FUNDED_CHANNEL) =>
227-
Helpers.validateParamsDualFundedInitiator(nodeParams, remoteNodeId, d.init.channelType, d.init.localChannelParams.initFeatures, d.init.remoteInit.features, d.lastSent, accept) match {
227+
Helpers.validateParamsDualFundedInitiator(nodeParams, remoteNodeId, d.init.localChannelParams.initFeatures, d.init.remoteInit.features, d.lastSent, accept) match {
228228
case Left(t) =>
229229
d.init.replyTo ! OpenChannelResponse.Rejected(t.getMessage)
230230
handleLocalError(t, d, Some(accept))

eclair-core/src/main/scala/fr/acinq/eclair/channel/fsm/ChannelOpenSingleFunded.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ trait ChannelOpenSingleFunded extends SingleFundingHandlers with ErrorHandlers {
112112

113113
when(WAIT_FOR_OPEN_CHANNEL)(handleExceptions {
114114
case Event(open: OpenChannel, d: DATA_WAIT_FOR_OPEN_CHANNEL) =>
115-
Helpers.validateParamsSingleFundedFundee(nodeParams, d.initFundee.channelType, d.initFundee.localChannelParams.initFeatures, open, remoteNodeId, d.initFundee.remoteInit.features) match {
115+
Helpers.validateParamsSingleFundedFundee(nodeParams, d.initFundee.localChannelParams.initFeatures, open, remoteNodeId, d.initFundee.remoteInit.features) match {
116116
case Left(t) => handleLocalError(t, d, Some(open))
117117
case Right((channelFeatures, remoteShutdownScript)) =>
118118
context.system.eventStream.publish(ChannelCreated(self, peer, remoteNodeId, isOpener = false, open.temporaryChannelId, open.feeratePerKw, None))
@@ -168,7 +168,7 @@ trait ChannelOpenSingleFunded extends SingleFundingHandlers with ErrorHandlers {
168168

169169
when(WAIT_FOR_ACCEPT_CHANNEL)(handleExceptions {
170170
case Event(accept: AcceptChannel, d: DATA_WAIT_FOR_ACCEPT_CHANNEL) =>
171-
Helpers.validateParamsSingleFundedFunder(nodeParams, d.initFunder.channelType, d.initFunder.localChannelParams.initFeatures, d.initFunder.remoteInit.features, d.lastSent, accept) match {
171+
Helpers.validateParamsSingleFundedFunder(nodeParams, d.initFunder.localChannelParams.initFeatures, d.initFunder.remoteInit.features, d.lastSent, accept) match {
172172
case Left(t) =>
173173
d.initFunder.replyTo ! OpenChannelResponse.Rejected(t.getMessage)
174174
handleLocalError(t, d, Some(accept))

eclair-core/src/main/scala/fr/acinq/eclair/io/OpenChannelInterceptor.scala

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ private class OpenChannelInterceptor(peer: ActorRef[Any],
132132
}
133133

134134
private def sanityCheckNonInitiator(request: OpenChannelNonInitiator): Behavior[Command] = {
135-
validateRemoteChannelType(request.temporaryChannelId, request.channelType_opt, request.localFeatures) match {
135+
ChannelTypes.areCompatible(request.temporaryChannelId, request.localFeatures, request.channelType_opt) match {
136136
case Right(channelType) =>
137137
val dualFunded = Features.canUseFeature(request.localFeatures, request.remoteFeatures, Features.DualFunding)
138138
val upfrontShutdownScript = Features.canUseFeature(request.localFeatures, request.remoteFeatures, Features.UpfrontShutdownScript)
@@ -274,17 +274,6 @@ private class OpenChannelInterceptor(peer: ActorRef[Any],
274274
}
275275
}
276276

277-
private def validateRemoteChannelType(temporaryChannelId: ByteVector32, remoteChannelType_opt: Option[ChannelType], localFeatures: Features[InitFeature]): Either[ChannelException, SupportedChannelType] = {
278-
remoteChannelType_opt match {
279-
// remote explicitly specifies a channel type: we check whether we want to allow it
280-
case Some(remoteChannelType) => ChannelTypes.areCompatible(localFeatures, remoteChannelType) match {
281-
case Some(acceptedChannelType) => Right(acceptedChannelType)
282-
case None => Left(InvalidChannelType(temporaryChannelId, remoteChannelType))
283-
}
284-
case None => Left(MissingChannelType(temporaryChannelId))
285-
}
286-
}
287-
288277
private def createLocalParams(nodeParams: NodeParams, initFeatures: Features[InitFeature], upfrontShutdownScript: Boolean, channelType: SupportedChannelType, isChannelOpener: Boolean, paysCommitTxFees: Boolean, dualFunded: Boolean, fundingAmount: Satoshi): LocalChannelParams = {
289278
makeChannelParams(
290279
nodeParams, initFeatures,

eclair-core/src/test/scala/fr/acinq/eclair/channel/states/ChannelStateTestsHelperMethods.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ trait ChannelStateTestsBase extends Assertions with Eventually {
260260
.modify(_.activated).usingIf(tags.contains(ChannelStateTestsTags.DualFunding))(_.updated(Features.DualFunding, FeatureSupport.Optional))
261261
.modify(_.activated).usingIf(tags.contains(ChannelStateTestsTags.SimpleClose))(_.updated(Features.SimpleClose, FeatureSupport.Optional))
262262
.modify(_.activated).usingIf(tags.contains(ChannelStateTestsTags.AnchorOutputsPhoenix))(_.removed(Features.AnchorOutputsZeroFeeHtlcTx).updated(Features.AnchorOutputs, FeatureSupport.Optional))
263-
.modify(_.activated).usingIf(tags.contains(ChannelStateTestsTags.OptionSimpleTaprootPhoenix))(_.removed(Features.SimpleTaprootChannelsStaging).updated(Features.SimpleTaprootChannelsPhoenix, FeatureSupport.Optional))
263+
.modify(_.activated).usingIf(tags.contains(ChannelStateTestsTags.OptionSimpleTaprootPhoenix))(_.removed(Features.SimpleTaprootChannelsStaging).updated(Features.SimpleTaprootChannelsPhoenix, FeatureSupport.Optional).updated(Features.PhoenixZeroReserve, FeatureSupport.Optional))
264264
.modify(_.activated).usingIf(tags.contains(ChannelStateTestsTags.OptionSimpleTaproot))(_.updated(Features.SimpleTaprootChannelsStaging, FeatureSupport.Optional))
265265
)
266266
val nodeParamsB1 = nodeParamsB.copy(features = nodeParamsB.features
@@ -272,7 +272,7 @@ trait ChannelStateTestsBase extends Assertions with Eventually {
272272
.modify(_.activated).usingIf(tags.contains(ChannelStateTestsTags.SimpleClose))(_.updated(Features.SimpleClose, FeatureSupport.Optional))
273273
.modify(_.activated).usingIf(tags.contains(ChannelStateTestsTags.DisableSplice))(_.removed(Features.SplicePrototype))
274274
.modify(_.activated).usingIf(tags.contains(ChannelStateTestsTags.AnchorOutputsPhoenix))(_.removed(Features.AnchorOutputsZeroFeeHtlcTx).updated(Features.AnchorOutputs, FeatureSupport.Optional))
275-
.modify(_.activated).usingIf(tags.contains(ChannelStateTestsTags.OptionSimpleTaprootPhoenix))(_.removed(Features.SimpleTaprootChannelsStaging).updated(Features.SimpleTaprootChannelsPhoenix, FeatureSupport.Optional))
275+
.modify(_.activated).usingIf(tags.contains(ChannelStateTestsTags.OptionSimpleTaprootPhoenix))(_.removed(Features.SimpleTaprootChannelsStaging).updated(Features.SimpleTaprootChannelsPhoenix, FeatureSupport.Optional).updated(Features.PhoenixZeroReserve, FeatureSupport.Optional))
276276
.modify(_.activated).usingIf(tags.contains(ChannelStateTestsTags.OptionSimpleTaproot))(_.updated(Features.SimpleTaprootChannelsStaging, FeatureSupport.Optional))
277277
)
278278
(nodeParamsA1, nodeParamsB1)

eclair-core/src/test/scala/fr/acinq/eclair/channel/states/a/WaitForOpenChannelStateSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ class WaitForOpenChannelStateSpec extends TestKitBaseClass with FixtureAnyFunSui
259259
test("recv OpenChannel (empty upfront shutdown script)", Tag(ChannelStateTestsTags.UpfrontShutdownScript)) { f =>
260260
import f._
261261
val open = alice2bob.expectMsgType[OpenChannel]
262-
val open1 = open.copy(tlvStream = TlvStream(ChannelTlv.UpfrontShutdownScriptTlv(ByteVector.empty)))
262+
val open1 = open.copy(tlvStream = TlvStream(open.tlvStream.records.filterNot(_.isInstanceOf[ChannelTlv.UpfrontShutdownScriptTlv]) + ChannelTlv.UpfrontShutdownScriptTlv(ByteVector.empty)))
263263
alice2bob.forward(bob, open1)
264264
awaitCond(bob.stateName == WAIT_FOR_FUNDING_CREATED)
265265
assert(bob.stateData.asInstanceOf[DATA_WAIT_FOR_FUNDING_CREATED].channelParams.remoteParams.upfrontShutdownScript_opt.isEmpty)

0 commit comments

Comments
 (0)