Skip to content

Commit 336a993

Browse files
committed
WIP
1 parent ef1f205 commit 336a993

File tree

11 files changed

+96
-65
lines changed

11 files changed

+96
-65
lines changed

eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala

Lines changed: 54 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import scodec.bits.ByteVector
2626
import scodec.codecs.uint32
2727

2828
import scala.annotation.tailrec
29-
import scala.concurrent.duration.FiniteDuration
29+
import scala.concurrent.duration.{DurationLong, FiniteDuration}
3030
import scala.util.{Failure, Success, Try}
3131

3232
/**
@@ -284,6 +284,10 @@ object Sphinx extends Logging {
284284
*/
285285
case class CannotDecryptFailurePacket(unwrapped: ByteVector)
286286

287+
case class HoldTime(duration: FiniteDuration, remoteNodeId: PublicKey)
288+
289+
case class HtlcFailure(holdTimes: Seq[HoldTime], failure: Either[CannotDecryptFailurePacket, DecryptedFailurePacket])
290+
287291
object FailurePacket {
288292

289293
/**
@@ -336,41 +340,64 @@ object Sphinx extends Logging {
336340
* @return failure message if the origin of the packet could be identified and the packet decrypted, the unwrapped
337341
* failure packet otherwise.
338342
*/
339-
@tailrec
340-
def decrypt(packet: ByteVector, sharedSecrets: Seq[SharedSecret]): Either[CannotDecryptFailurePacket, DecryptedFailurePacket] = {
343+
def decrypt(packet: ByteVector, attribution_opt: Option[ByteVector], sharedSecrets: Seq[SharedSecret], hopIndex: Int = 0): HtlcFailure = {
341344
sharedSecrets match {
342-
case Nil => Left(CannotDecryptFailurePacket(packet))
345+
case Nil => HtlcFailure(Nil, Left(CannotDecryptFailurePacket(packet)))
343346
case ss :: tail =>
344347
val packet1 = wrap(packet, ss.secret)
348+
val next_opt = attribution_opt.flatMap(Attribution.unwrap(_, packet1, ss.secret, hopIndex))
345349
val um = generateKey("um", ss.secret)
346-
FailureMessageCodecs.failureOnionCodec(Hmac256(um)).decode(packet1.toBitVector) match {
347-
case Attempt.Successful(value) => Right(DecryptedFailurePacket(ss.remoteNodeId, value.value))
348-
case _ => decrypt(packet1, tail)
350+
val HtlcFailure(holdTimes, failure) = FailureMessageCodecs.failureOnionCodec(Hmac256(um)).decode(packet1.toBitVector) match {
351+
case Attempt.Successful(value) => HtlcFailure(Nil, Right(DecryptedFailurePacket(ss.remoteNodeId, value.value)))
352+
case _ => decrypt(packet1, next_opt.map(_._2), tail, hopIndex + 1)
349353
}
354+
HtlcFailure(next_opt.map(n => HoldTime(n._1, ss.remoteNodeId) +: holdTimes).getOrElse(Nil), failure)
350355
}
351356
}
352357

353-
def attribution(previousAttribution_opt: Option[ByteVector], reason: ByteVector, holdTime: FiniteDuration, sharedSecret: ByteVector32): ByteVector = {
354-
val previousAttribution = previousAttribution_opt.getOrElse(ByteVector.low(920))
355-
val previousHmacs = (0 until 19).map(i => (1 until (20 - i)).map(j => {
356-
val start = 80 + (20 * i - (i * (i - 1)) / 2 + j) * 4
357-
previousAttribution.slice(start, start + 4)
358-
}))
359-
val mac = Hmac256(generateKey("um", sharedSecret))
360-
val holdTimes = uint32.encode(holdTime.toMillis).require.bytes ++ previousAttribution.take(19 * 4)
361-
val hmacs = computeHmacs(mac, reason, holdTimes, previousHmacs) +: previousHmacs
362-
val key = generateKey("ammagext", sharedSecret)
363-
val stream = generateStream(key, 920)
364-
(holdTimes ++ ByteVector.concat(hmacs.map(ByteVector.concat(_)))) xor stream
365-
}
358+
object Attribution {
359+
private def cipher(bytes: ByteVector, sharedSecret: ByteVector32): ByteVector = {
360+
val key = generateKey("ammagext", sharedSecret)
361+
val stream = generateStream(key, 920)
362+
bytes xor stream
363+
}
364+
365+
private def getHmacs(bytes: ByteVector): Seq[Seq[ByteVector]] =
366+
(0 until 20).map(i => (0 until (20 - i)).map(j => {
367+
val start = (20 + 20 * i - (i * (i - 1)) / 2 + j) * 4
368+
bytes.slice(start, start + 4)
369+
}))
370+
371+
private def computeHmacs(mac: Mac32, reason: ByteVector, holdTimes: ByteVector, hmacs: Seq[Seq[ByteVector]], minNumHop: Int): Seq[ByteVector] = {
372+
(minNumHop until 20).map(i => {
373+
val y = 20 - i
374+
mac.mac(reason ++
375+
holdTimes.take(y * 4) ++
376+
ByteVector.concat((0 until y - 1).map(j => hmacs(j)(i)))).bytes.take(4)
377+
})
378+
}
379+
380+
def create(previousAttribution_opt: Option[ByteVector], reason: ByteVector, holdTime: FiniteDuration, sharedSecret: ByteVector32): ByteVector = {
381+
val previousAttribution = previousAttribution_opt.getOrElse(ByteVector.low(920))
382+
val previousHmacs = getHmacs(previousAttribution).dropRight(1).map(_.drop(1))
383+
val mac = Hmac256(generateKey("um", sharedSecret))
384+
val holdTimes = uint32.encode(holdTime.toMillis).require.bytes ++ previousAttribution.take(19 * 4)
385+
val hmacs = computeHmacs(mac, reason, holdTimes, previousHmacs, 0) +: previousHmacs
386+
cipher(holdTimes ++ ByteVector.concat(hmacs.map(ByteVector.concat(_))), sharedSecret)
387+
}
366388

367-
private def computeHmacs(mac: Mac32, reason: ByteVector, holdTimes: ByteVector, hmacs: Seq[Seq[ByteVector]]): Seq[ByteVector] = {
368-
(0 until 20).map(i => {
369-
val y = 20 - i
370-
mac.mac(reason ++
371-
holdTimes.take(y * 4) ++
372-
ByteVector.concat((0 until y - 1).map(j => hmacs(j)(i)))).bytes.take(4)
373-
})
389+
def unwrap(encrypted: ByteVector, reason: ByteVector, sharedSecret: ByteVector32, minNumHop: Int): Option[(FiniteDuration, ByteVector)] = {
390+
val bytes = cipher(encrypted, sharedSecret)
391+
val holdTime = uint32.decode(bytes.take(4).bits).require.value.milliseconds
392+
val hmacs = getHmacs(bytes)
393+
val mac = Hmac256(generateKey("um", sharedSecret))
394+
if (computeHmacs(mac, reason, bytes.take(20 * 4), hmacs.drop(1), minNumHop) == hmacs.head.drop(minNumHop)) {
395+
val unwraped = bytes.slice(4, 20 * 4) ++ ByteVector.low(4) ++ ByteVector.concat((hmacs.drop(1) :+ Seq()).map(s => ByteVector.low(4) ++ ByteVector.concat(s)))
396+
Some(holdTime, unwraped)
397+
} else {
398+
None
399+
}
400+
}
374401
}
375402
}
376403

eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ object FailureSummary {
250250
def apply(f: PaymentFailure): FailureSummary = f match {
251251
case LocalFailure(_, route, t) => FailureSummary(FailureType.LOCAL, t.getMessage, route.map(h => HopSummary(h)).toList, route.headOption.map(_.nodeId))
252252
case RemoteFailure(_, route, e) => FailureSummary(FailureType.REMOTE, e.failureMessage.message, route.map(h => HopSummary(h)).toList, Some(e.originNode))
253-
case UnreadableRemoteFailure(_, route, _) => FailureSummary(FailureType.UNREADABLE_REMOTE, "could not decrypt failure onion", route.map(h => HopSummary(h)).toList, None)
253+
case UnreadableRemoteFailure(_, route, _, _) => FailureSummary(FailureType.UNREADABLE_REMOTE, "could not decrypt failure onion", route.map(h => HopSummary(h)).toList, None)
254254
}
255255
}
256256

eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package fr.acinq.eclair.payment
1919
import fr.acinq.bitcoin.scalacompat.ByteVector32
2020
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
2121
import fr.acinq.eclair.crypto.Sphinx
22+
import fr.acinq.eclair.crypto.Sphinx.HoldTime
2223
import fr.acinq.eclair.payment.Invoice.ExtraEdge
2324
import fr.acinq.eclair.payment.send.PaymentError.RetryExhausted
2425
import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentConfig
@@ -150,7 +151,7 @@ case class LocalFailure(amount: MilliSatoshi, route: Seq[Hop], t: Throwable) ext
150151
case class RemoteFailure(amount: MilliSatoshi, route: Seq[Hop], e: Sphinx.DecryptedFailurePacket) extends PaymentFailure
151152

152153
/** A remote node failed the payment but we couldn't decrypt the failure (e.g. a malicious node tampered with the message). */
153-
case class UnreadableRemoteFailure(amount: MilliSatoshi, route: Seq[Hop], failurePacket: ByteVector) extends PaymentFailure
154+
case class UnreadableRemoteFailure(amount: MilliSatoshi, route: Seq[Hop], failurePacket: ByteVector, holdTimes: Seq[HoldTime]) extends PaymentFailure
154155

155156
object PaymentFailure {
156157

@@ -235,13 +236,14 @@ object PaymentFailure {
235236
}
236237
case RemoteFailure(_, hops, Sphinx.DecryptedFailurePacket(nodeId, _)) =>
237238
ignoreNodeOutgoingEdge(nodeId, hops, ignore)
238-
case UnreadableRemoteFailure(_, hops, _) =>
239+
case UnreadableRemoteFailure(_, hops, _, holdTimes) => // TODO
239240
// We don't know which node is sending garbage, let's blacklist all nodes except:
241+
// - the nodes that returned attribution data (except the last one)
240242
// - the one we are directly connected to: it would be too restrictive for retries
241243
// - the final recipient: they have no incentive to send garbage since they want that payment
242244
// - the introduction point of a blinded route: we don't want a node before the blinded path to force us to ignore that blinded path
243245
// - the trampoline node: we don't want a node before the trampoline node to force us to ignore that trampoline node
244-
val blacklist = hops.collect { case hop: ChannelHop => hop }.map(_.nextNodeId).drop(1).dropRight(1).toSet
246+
val blacklist = hops.collect { case hop: ChannelHop => hop }.map(_.nextNodeId).drop(1 max (holdTimes.length - 1)).dropRight(1).toSet
245247
ignore ++ blacklist
246248
case LocalFailure(_, hops, _) => hops.headOption match {
247249
case Some(hop: ChannelHop) =>

eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,15 +351,15 @@ object OutgoingPaymentPacket {
351351
reason match {
352352
case FailureReason.EncryptedDownstreamFailure(packet, attribution) =>
353353
val tlvs: TlvStream[UpdateFailHtlcTlv] = if (useAttributableFailures) {
354-
TlvStream(UpdateFailHtlcTlv.AttributionData(Sphinx.FailurePacket.attribution(attribution, packet, holdTime, sharedSecret)))
354+
TlvStream(UpdateFailHtlcTlv.AttributionData(Sphinx.FailurePacket.Attribution.create(attribution, packet, holdTime, sharedSecret)))
355355
} else {
356356
TlvStream.empty
357357
}
358358
(Sphinx.FailurePacket.wrap(packet, sharedSecret), tlvs)
359359
case FailureReason.LocalFailure(failure) =>
360360
val packet = Sphinx.FailurePacket.create(sharedSecret, failure)
361361
val tlvs: TlvStream[UpdateFailHtlcTlv] = if (useAttributableFailures) {
362-
TlvStream(UpdateFailHtlcTlv.AttributionData(Sphinx.FailurePacket.attribution(None, packet, holdTime, sharedSecret)))
362+
TlvStream(UpdateFailHtlcTlv.AttributionData(Sphinx.FailurePacket.Attribution.create(None, packet, holdTime, sharedSecret)))
363363
} else {
364364
TlvStream.empty
365365
}

eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/OnTheFlyFunding.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ object OnTheFlyFunding {
107107
case f: FailureReason.EncryptedDownstreamFailure =>
108108
// In the trampoline case, we currently ignore downstream failures: we should add dedicated failures to
109109
// the BOLTs to better handle those cases.
110-
Sphinx.FailurePacket.decrypt(f.packet, onionSharedSecrets) match {
110+
Sphinx.FailurePacket.decrypt(f.packet, f.attribution_opt, onionSharedSecrets).failure match {
111111
case Left(Sphinx.CannotDecryptFailurePacket(_)) =>
112112
log.warning("couldn't decrypt downstream on-the-fly funding failure")
113113
case Right(f) =>

eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,12 +164,13 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A
164164

165165
private def handleRemoteFail(d: WaitingForComplete, fail: UpdateFailHtlc) = {
166166
import d._
167-
((Sphinx.FailurePacket.decrypt(fail.reason, sharedSecrets) match {
167+
val htlcFailure = Sphinx.FailurePacket.decrypt(fail.reason, fail.attribution_opt, sharedSecrets)
168+
((htlcFailure.failure match {
168169
case success@Right(e) =>
169170
Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(RemoteFailure(request.amount, Nil, e))).increment()
170171
success
171172
case failure@Left(e) =>
172-
Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(UnreadableRemoteFailure(request.amount, Nil, e.unwrapped))).increment()
173+
Metrics.PaymentError.withTag(Tags.Failure, Tags.FailureType(UnreadableRemoteFailure(request.amount, Nil, e.unwrapped, htlcFailure.holdTimes))).increment()
173174
failure
174175
}) match {
175176
case res@Right(Sphinx.DecryptedFailurePacket(nodeId, failureMessage)) =>
@@ -217,13 +218,13 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A
217218
RemoteFailure(request.amount, route.fullRoute, e)
218219
case Left(Sphinx.CannotDecryptFailurePacket(unwrapped)) =>
219220
log.warning(s"cannot parse returned error ${fail.reason.toHex} with sharedSecrets=$sharedSecrets: unwrapped=$unwrapped")
220-
UnreadableRemoteFailure(request.amount, route.fullRoute, unwrapped)
221+
UnreadableRemoteFailure(request.amount, route.fullRoute, unwrapped, htlcFailure.holdTimes)
221222
}
222223
log.warning(s"too many failed attempts, failing the payment")
223224
myStop(request, Left(PaymentFailed(id, paymentHash, failures :+ failure)))
224225
case Left(Sphinx.CannotDecryptFailurePacket(unwrapped)) =>
225226
log.warning(s"cannot parse returned error: unwrapped=$unwrapped, route=${route.printNodes()}")
226-
val failure = UnreadableRemoteFailure(request.amount, route.fullRoute, unwrapped)
227+
val failure = UnreadableRemoteFailure(request.amount, route.fullRoute, unwrapped, htlcFailure.holdTimes)
227228
retry(failure, d)
228229
case Right(e@Sphinx.DecryptedFailurePacket(nodeId, failureMessage: Node)) =>
229230
log.info(s"received 'Node' type error message from nodeId=$nodeId, trying to route around it (failure=$failureMessage)")

0 commit comments

Comments
 (0)